From ab1fbfc0d6c20c911bb5bf3576939b2223bb5c0d Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Sun, 10 Apr 2022 23:33:41 -0700 Subject: [PATCH] Enhanced reportUnnecessaryComparison diagnostic check so it also detects cases where a function appears within a condition expression. This is a common source of programming error. --- docs/configuration.md | 2 +- .../pyright-internal/src/analyzer/checker.ts | 47 +++++++++++++++++++ .../src/localization/localize.ts | 1 + .../src/localization/package.nls.en-us.json | 1 + .../src/tests/samples/comparison2.py | 37 +++++++++++++++ .../src/tests/typeEvaluator3.test.ts | 11 +++++ 6 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 packages/pyright-internal/src/tests/samples/comparison2.py diff --git a/docs/configuration.md b/docs/configuration.md index 1a3e0be80..3d7d14173 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -142,7 +142,7 @@ The following settings control pyright’s diagnostic output (warnings or errors **reportUnnecessaryCast** [boolean or string, optional]: Generate or suppress diagnostics for 'cast' calls that are statically determined to be unnecessary. Such calls are sometimes indicative of a programming error. The default value for this setting is 'none'. -**reportUnnecessaryComparison** [boolean or string, optional]: Generate or suppress diagnostics for '==' or '!=' comparisons that are statically determined to always evaluate to False or True. Such comparisons are sometimes indicative of a programming error. The default value for this setting is 'none'. +**reportUnnecessaryComparison** [boolean or string, optional]: Generate or suppress diagnostics for '==' or '!=' comparisons or other conditional expressions that are statically determined to always evaluate to False or True. Such comparisons are sometimes indicative of a programming error. The default value for this setting is 'none'. **reportAssertAlwaysTrue** [boolean or string, optional]: Generate or suppress diagnostics for 'assert' statement that will provably always assert. This can be indicative of a programming error. The default value for this setting is 'warning'. diff --git a/packages/pyright-internal/src/analyzer/checker.ts b/packages/pyright-internal/src/analyzer/checker.ts index 5158e058e..8432ac582 100644 --- a/packages/pyright-internal/src/analyzer/checker.ts +++ b/packages/pyright-internal/src/analyzer/checker.ts @@ -36,6 +36,7 @@ import { DictionaryNode, ErrorNode, ExceptNode, + ExpressionNode, FormatStringNode, ForNode, FunctionNode, @@ -47,6 +48,7 @@ import { IndexNode, isExpressionNode, LambdaNode, + ListComprehensionIfNode, ListComprehensionNode, ListNode, MatchNode, @@ -760,13 +762,20 @@ export class Checker extends ParseTreeWalker { return true; } + override visitListComprehensionIf(node: ListComprehensionIfNode): boolean { + this._reportUnnecessaryConditionExpression(node.testExpression); + return true; + } + override visitIf(node: IfNode): boolean { this._evaluator.getType(node.testExpression); + this._reportUnnecessaryConditionExpression(node.testExpression); return true; } override visitWhile(node: WhileNode): boolean { this._evaluator.getType(node.testExpression); + this._reportUnnecessaryConditionExpression(node.testExpression); return true; } @@ -1113,6 +1122,7 @@ export class Checker extends ParseTreeWalker { override visitTernary(node: TernaryNode): boolean { this._evaluator.getType(node); + this._reportUnnecessaryConditionExpression(node.testExpression); return true; } @@ -1325,6 +1335,43 @@ export class Checker extends ParseTreeWalker { return false; } + private _reportUnnecessaryConditionExpression(expression: ExpressionNode) { + if (expression.nodeType === ParseNodeType.BinaryOperation) { + if (expression.operator === OperatorType.And || expression.operator === OperatorType.Or) { + this._reportUnnecessaryConditionExpression(expression.leftExpression); + this._reportUnnecessaryConditionExpression(expression.rightExpression); + } + + return; + } else if (expression.nodeType === ParseNodeType.UnaryOperation) { + if (expression.operator === OperatorType.Not) { + this._reportUnnecessaryConditionExpression(expression.expression); + } + + return; + } + + const exprTypeResult = this._evaluator.getTypeOfExpression(expression); + let isExprFunction = true; + + doForEachSubtype(exprTypeResult.type, (subtype) => { + subtype = this._evaluator.makeTopLevelTypeVarsConcrete(subtype); + + if (!isFunction(subtype) && !isOverloadedFunction(subtype)) { + isExprFunction = false; + } + }); + + if (isExprFunction) { + this._evaluator.addDiagnostic( + this._fileInfo.diagnosticRuleSet.reportUnnecessaryComparison, + DiagnosticRule.reportUnnecessaryComparison, + Localizer.Diagnostic.functionInConditionalExpression(), + expression + ); + } + } + private _reportUnusedExpression(node: ParseNode) { if (this._fileInfo.diagnosticRuleSet.reportUnusedExpression === 'none') { return; diff --git a/packages/pyright-internal/src/localization/localize.ts b/packages/pyright-internal/src/localization/localize.ts index 879f4eae2..76121f52d 100644 --- a/packages/pyright-internal/src/localization/localize.ts +++ b/packages/pyright-internal/src/localization/localize.ts @@ -426,6 +426,7 @@ export namespace Localizer { export const formatStringUnicode = () => getRawString('Diagnostic.formatStringUnicode'); export const formatStringUnterminated = () => getRawString('Diagnostic.formatStringUnterminated'); export const functionDecoratorTypeUnknown = () => getRawString('Diagnostic.functionDecoratorTypeUnknown'); + export const functionInConditionalExpression = () => getRawString('Diagnostic.functionInConditionalExpression'); export const generatorAsyncReturnType = () => getRawString('Diagnostic.generatorAsyncReturnType'); export const generatorNotParenthesized = () => getRawString('Diagnostic.generatorNotParenthesized'); export const generatorSyncReturnType = () => getRawString('Diagnostic.generatorSyncReturnType'); diff --git a/packages/pyright-internal/src/localization/package.nls.en-us.json b/packages/pyright-internal/src/localization/package.nls.en-us.json index 583d43206..41bf10c54 100644 --- a/packages/pyright-internal/src/localization/package.nls.en-us.json +++ b/packages/pyright-internal/src/localization/package.nls.en-us.json @@ -172,6 +172,7 @@ "formatStringIllegal": "Format string literals (f-strings) require Python 3.6 or newer", "formatStringUnterminated": "Unterminated expression in f-string; missing close brace", "functionDecoratorTypeUnknown": "Untyped function decorator obscures type of function; ignoring decorator", + "functionInConditionalExpression": "Conditional expression references function which always evaluates to True", "generatorAsyncReturnType": "Return type of async generator function must be \"AsyncGenerator\" or \"AsyncIterable\"", "generatorNotParenthesized": "Generator expressions must be parenthesized if not sole argument", "generatorSyncReturnType": "Return type of generator function must be \"Generator\" or \"Iterable\"", diff --git a/packages/pyright-internal/src/tests/samples/comparison2.py b/packages/pyright-internal/src/tests/samples/comparison2.py new file mode 100644 index 000000000..e60b764ec --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/comparison2.py @@ -0,0 +1,37 @@ +# This sample tests the reportUnnecessaryComparison diagnostic check +# when applied to functions that appear within a conditional expression. + + +def cond() -> bool: + ... + + +# This should generate a diagnostic when reportUnnecessaryComparison is enabled. +if cond: + pass + +# This should generate a diagnostic when reportUnnecessaryComparison is enabled. +if 0 or cond: + pass + +# This should generate a diagnostic when reportUnnecessaryComparison is enabled. +if 1 and cond: + pass + +if cond(): + pass +# This should generate a diagnostic when reportUnnecessaryComparison is enabled. +elif cond: + pass + +# This should generate a diagnostic when reportUnnecessaryComparison is enabled. +def func1(): + while cond: + pass + + +# This should generate a diagnostic when reportUnnecessaryComparison is enabled. +a = [x for x in range(20) if cond] + +# This should generate a diagnostic when reportUnnecessaryComparison is enabled. +a = 1 if cond else 2 diff --git a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts index 9c7a40596..a09b2b399 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts @@ -933,6 +933,17 @@ test('Comparison1', () => { TestUtils.validateResults(analysisResults2, 7); }); +test('Comparison2', () => { + const configOptions = new ConfigOptions('.'); + + const analysisResults1 = TestUtils.typeAnalyzeSampleFiles(['comparison2.py'], configOptions); + TestUtils.validateResults(analysisResults1, 0); + + configOptions.diagnosticRuleSet.reportUnnecessaryComparison = 'error'; + const analysisResults2 = TestUtils.typeAnalyzeSampleFiles(['comparison2.py'], configOptions); + TestUtils.validateResults(analysisResults2, 7); +}); + test('EmptyContainers1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['emptyContainers1.py']); TestUtils.validateResults(analysisResults, 3);