diff --git a/docs/configuration.md b/docs/configuration.md index c16c472e3..9e87b99a2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -130,6 +130,8 @@ The following settings control pyright’s diagnostic output (warnings or errors **reportUnnecessaryCast** [boolean or string, optional]: Generate or suppress diagnostics for 'cast' calls that are statically determined to be unnecessary. Such calls are sometimes indicative of a programming error. The default value for this setting is 'none'. +**reportUnnecessaryComparison** [boolean or string, optional]: Generate or suppress diagnostics for '==' or '!=' comparisons that are statically determined to always evaluate to False or True. Such comparisons are sometimes indicative of a programming error. The default value for this setting is 'none'. + **reportAssertAlwaysTrue** [boolean or string, optional]: Generate or suppress diagnostics for 'assert' statement that will provably always assert. This can be indicative of a programming error. The default value for this setting is 'warning'. **reportSelfClsParameterName** [boolean or string, optional]: Generate or suppress diagnostics for a missing or misnamed “self” parameter in instance methods and “cls” parameter in class methods. Instance methods in metaclasses (classes that derive from “type”) are allowed to use “cls” for instance methods. The default value for this setting is 'warning'. @@ -274,6 +276,7 @@ The following table lists the default severity levels for each diagnostic rule w | reportCallInDefaultInitializer | "none" | "none" | "none" | | reportUnnecessaryIsInstance | "none" | "none" | "error" | | reportUnnecessaryCast | "none" | "none" | "error" | +| reportUnnecessaryComparison | "none" | "none" | "error" | | reportAssertAlwaysTrue | "none" | "warning" | "error" | | reportSelfClsParameterName | "none" | "warning" | "error" | | reportImplicitStringConcatenation | "none" | "none" | "none" | diff --git a/packages/pyright-internal/src/analyzer/checker.ts b/packages/pyright-internal/src/analyzer/checker.ts index ebfedae20..96a74211e 100644 --- a/packages/pyright-internal/src/analyzer/checker.ts +++ b/packages/pyright-internal/src/analyzer/checker.ts @@ -92,6 +92,7 @@ import { isAnyOrUnknown, isClass, isFunction, + isModule, isNever, isNone, isObject, @@ -613,44 +614,6 @@ export class Checker extends ParseTreeWalker { } visitIf(node: IfNode): boolean { - // Check for expressions where a variable is being compared to - // a literal string or number. Look for a common bug where - // the comparison will always be False. Don't do this for - // expressions like 'sys.platform == "win32"' because those - // can change based on the execution environment and are therefore - // valid. - if ( - node.testExpression.nodeType === ParseNodeType.BinaryOperation && - node.testExpression.operator === OperatorType.Equals && - evaluateStaticBoolExpression(node.testExpression, this._fileInfo.executionEnvironment) === undefined - ) { - const rightType = this._evaluator.getType(node.testExpression.rightExpression); - if (rightType && isLiteralTypeOrUnion(rightType)) { - const leftType = this._evaluator.getType(node.testExpression.leftExpression); - if (leftType && isLiteralTypeOrUnion(leftType)) { - let isPossiblyTrue = false; - - doForEachSubtype(leftType, (leftSubtype) => { - if (this._evaluator.canAssignType(rightType, leftSubtype, new DiagnosticAddendum())) { - isPossiblyTrue = true; - } - }); - - if (!isPossiblyTrue) { - this._evaluator.addDiagnostic( - this._fileInfo.diagnosticRuleSet.reportGeneralTypeIssues, - DiagnosticRule.reportGeneralTypeIssues, - Localizer.Diagnostic.comparisonAlwaysFalse().format({ - leftType: this._evaluator.printType(leftType, /* expandTypeAlias */ true), - rightType: this._evaluator.printType(rightType, /* expandTypeAlias */ true), - }), - node.testExpression - ); - } - } - } - } - this._evaluator.getType(node.testExpression); return true; } @@ -928,6 +891,13 @@ export class Checker extends ParseTreeWalker { } visitBinaryOperation(node: BinaryOperationNode): boolean { + if (node.operator === OperatorType.Equals || node.operator === OperatorType.NotEquals) { + // Don't apply this rule if it's within an assert. + if (!ParseTreeUtils.isWithinAssertExpression(node)) { + this._validateComparisonTypes(node); + } + } + this._evaluator.getType(node); return true; } @@ -1068,6 +1038,169 @@ export class Checker extends ParseTreeWalker { return false; } + // Determines whether the types of the two operands for an == or != operation + // have overlapping types. + private _validateComparisonTypes(node: BinaryOperationNode) { + const leftType = this._evaluator.getType(node.leftExpression); + const rightType = this._evaluator.getType(node.rightExpression); + + if (!leftType || !rightType) { + return; + } + + // Check for the special case where the LHS and RHS are both literals. + if (isLiteralTypeOrUnion(rightType) && isLiteralTypeOrUnion(leftType)) { + if (evaluateStaticBoolExpression(node, this._fileInfo.executionEnvironment) === undefined) { + let isPossiblyTrue = false; + + doForEachSubtype(leftType, (leftSubtype) => { + if (this._evaluator.canAssignType(rightType, leftSubtype, new DiagnosticAddendum())) { + isPossiblyTrue = true; + } + }); + + if (!isPossiblyTrue) { + this._evaluator.addDiagnostic( + this._fileInfo.diagnosticRuleSet.reportUnnecessaryComparison, + DiagnosticRule.reportUnnecessaryComparison, + Localizer.Diagnostic.comparisonAlwaysFalse().format({ + leftType: this._evaluator.printType(leftType, /* expandTypeAlias */ true), + rightType: this._evaluator.printType(rightType, /* expandTypeAlias */ true), + }), + node + ); + } + } + } else { + let isComparable = false; + + doForEachSubtype(leftType, (leftSubtype) => { + if (isComparable) { + return; + } + + leftSubtype = transformTypeObjectToClass(leftSubtype); + leftSubtype = this._evaluator.makeTopLevelTypeVarsConcrete(leftSubtype); + doForEachSubtype(rightType, (rightSubtype) => { + if (isComparable) { + return; + } + + rightSubtype = transformTypeObjectToClass(rightSubtype); + rightSubtype = this._evaluator.makeTopLevelTypeVarsConcrete(rightSubtype); + + if (this._isTypeComparable(leftSubtype, rightSubtype)) { + isComparable = true; + } + }); + }); + + if (!isComparable) { + const leftTypeText = this._evaluator.printType(leftType, /* expandTypeAlias */ true); + const rightTypeText = this._evaluator.printType(rightType, /* expandTypeAlias */ true); + + const message = + node.operator === OperatorType.Equals + ? Localizer.Diagnostic.comparisonAlwaysFalse() + : Localizer.Diagnostic.comparisonAlwaysTrue(); + + this._evaluator.addDiagnostic( + this._fileInfo.diagnosticRuleSet.reportUnnecessaryComparison, + DiagnosticRule.reportUnnecessaryComparison, + message.format({ + leftType: leftTypeText, + rightType: rightTypeText, + }), + node + ); + } + } + } + + // Determines whether the two types are potentially comparable -- i.e. + // their types overlap in such a way that it makes sense for them to + // be compared with an == or != operator. + private _isTypeComparable(leftType: Type, rightType: Type) { + if (isAnyOrUnknown(leftType) || isAnyOrUnknown(rightType)) { + return true; + } + + if (isNever(leftType) || isNever(rightType)) { + return false; + } + + if (isModule(leftType) || isModule(rightType)) { + return !isTypeSame(leftType, rightType); + } + + if (isNone(leftType) || isNone(rightType)) { + return !isTypeSame(leftType, rightType); + } + + if (isClass(leftType)) { + if (isClass(rightType)) { + const genericLeftType = ClassType.cloneForSpecialization( + leftType, + /* typeArguments */ undefined, + /* isTypeArgumentExplicit */ false + ); + const genericRightType = ClassType.cloneForSpecialization( + rightType, + /* typeArguments */ undefined, + /* isTypeArgumentExplicit */ false + ); + + if ( + this._evaluator.canAssignType(genericLeftType, genericRightType, new DiagnosticAddendum()) || + this._evaluator.canAssignType(genericRightType, genericLeftType, new DiagnosticAddendum()) + ) { + return true; + } + } + + // Does the class have an operator overload for eq? + const metaclass = leftType.details.effectiveMetaclass; + if (metaclass && isClass(metaclass)) { + if (lookUpClassMember(metaclass, '__eq__', ClassMemberLookupFlags.SkipObjectBaseClass)) { + return true; + } + } + + return false; + } + + if (isObject(leftType)) { + if (isObject(rightType)) { + const genericLeftType = ClassType.cloneForSpecialization( + leftType.classType, + /* typeArguments */ undefined, + /* isTypeArgumentExplicit */ false + ); + const genericRightType = ClassType.cloneForSpecialization( + rightType.classType, + /* typeArguments */ undefined, + /* isTypeArgumentExplicit */ false + ); + + if ( + this._evaluator.canAssignType(genericLeftType, genericRightType, new DiagnosticAddendum()) || + this._evaluator.canAssignType(genericRightType, genericLeftType, new DiagnosticAddendum()) + ) { + return true; + } + } + + // Does the class have an operator overload for eq? + if (lookUpClassMember(leftType.classType, '__eq__', ClassMemberLookupFlags.SkipObjectBaseClass)) { + return true; + } + + return false; + } + + return true; + } + // Determines whether the specified type is one that should trigger // an "unused" value diagnostic. private _isTypeValidForUnusedValueTest(type: Type) { diff --git a/packages/pyright-internal/src/analyzer/parseTreeUtils.ts b/packages/pyright-internal/src/analyzer/parseTreeUtils.ts index 887cd4ee1..607d29f09 100644 --- a/packages/pyright-internal/src/analyzer/parseTreeUtils.ts +++ b/packages/pyright-internal/src/analyzer/parseTreeUtils.ts @@ -1039,6 +1039,24 @@ export function isWithinTryBlock(node: ParseNode): boolean { return false; } +export function isWithinAssertExpression(node: ParseNode): boolean { + let curNode: ParseNode | undefined = node; + let prevNode: ParseNode | undefined; + + while (curNode) { + switch (curNode.nodeType) { + case ParseNodeType.Assert: { + return curNode.testExpression === prevNode; + } + } + + prevNode = curNode; + curNode = curNode.parent; + } + + return false; +} + export function getDocString(statements: StatementNode[]): string | undefined { // See if the first statement in the suite is a triple-quote string. if (statements.length === 0) { diff --git a/packages/pyright-internal/src/common/configOptions.ts b/packages/pyright-internal/src/common/configOptions.ts index b00a59d71..f62c9b120 100644 --- a/packages/pyright-internal/src/common/configOptions.ts +++ b/packages/pyright-internal/src/common/configOptions.ts @@ -223,6 +223,9 @@ export interface DiagnosticRuleSet { // to always unnecessary. reportUnnecessaryCast: DiagnosticLevel; + // Report == or != operators that always evaluate to True or False. + reportUnnecessaryComparison: DiagnosticLevel; + // Report assert expressions that will always evaluate to true. reportAssertAlwaysTrue: DiagnosticLevel; @@ -316,6 +319,7 @@ export function getDiagLevelDiagnosticRules() { DiagnosticRule.reportCallInDefaultInitializer, DiagnosticRule.reportUnnecessaryIsInstance, DiagnosticRule.reportUnnecessaryCast, + DiagnosticRule.reportUnnecessaryComparison, DiagnosticRule.reportAssertAlwaysTrue, DiagnosticRule.reportSelfClsParameterName, DiagnosticRule.reportImplicitStringConcatenation, @@ -385,6 +389,7 @@ export function getOffDiagnosticRuleSet(): DiagnosticRuleSet { reportCallInDefaultInitializer: 'none', reportUnnecessaryIsInstance: 'none', reportUnnecessaryCast: 'none', + reportUnnecessaryComparison: 'none', reportAssertAlwaysTrue: 'none', reportSelfClsParameterName: 'none', reportImplicitStringConcatenation: 'none', @@ -450,6 +455,7 @@ export function getBasicDiagnosticRuleSet(): DiagnosticRuleSet { reportCallInDefaultInitializer: 'none', reportUnnecessaryIsInstance: 'none', reportUnnecessaryCast: 'none', + reportUnnecessaryComparison: 'none', reportAssertAlwaysTrue: 'warning', reportSelfClsParameterName: 'warning', reportImplicitStringConcatenation: 'none', @@ -515,6 +521,7 @@ export function getStrictDiagnosticRuleSet(): DiagnosticRuleSet { reportCallInDefaultInitializer: 'none', reportUnnecessaryIsInstance: 'error', reportUnnecessaryCast: 'error', + reportUnnecessaryComparison: 'error', reportAssertAlwaysTrue: 'error', reportSelfClsParameterName: 'error', reportImplicitStringConcatenation: 'none', @@ -1129,6 +1136,13 @@ export class ConfigOptions { defaultSettings.reportUnnecessaryCast ), + // Read the "reportUnnecessaryComparison" entry. + reportUnnecessaryComparison: this._convertDiagnosticLevel( + configObj.reportUnnecessaryComparison, + DiagnosticRule.reportUnnecessaryComparison, + defaultSettings.reportUnnecessaryComparison + ), + // Read the "reportAssertAlwaysTrue" entry. reportAssertAlwaysTrue: this._convertDiagnosticLevel( configObj.reportAssertAlwaysTrue, diff --git a/packages/pyright-internal/src/common/diagnosticRules.ts b/packages/pyright-internal/src/common/diagnosticRules.ts index 86335d7c8..afdbb2a90 100644 --- a/packages/pyright-internal/src/common/diagnosticRules.ts +++ b/packages/pyright-internal/src/common/diagnosticRules.ts @@ -57,6 +57,7 @@ export enum DiagnosticRule { reportCallInDefaultInitializer = 'reportCallInDefaultInitializer', reportUnnecessaryIsInstance = 'reportUnnecessaryIsInstance', reportUnnecessaryCast = 'reportUnnecessaryCast', + reportUnnecessaryComparison = 'reportUnnecessaryComparison', reportAssertAlwaysTrue = 'reportAssertAlwaysTrue', reportSelfClsParameterName = 'reportSelfClsParameterName', reportImplicitStringConcatenation = 'reportImplicitStringConcatenation', diff --git a/packages/pyright-internal/src/localization/localize.ts b/packages/pyright-internal/src/localization/localize.ts index 2c0e62f71..68c7979d7 100644 --- a/packages/pyright-internal/src/localization/localize.ts +++ b/packages/pyright-internal/src/localization/localize.ts @@ -228,6 +228,10 @@ export namespace Localizer { new ParameterizedString<{ leftType: string; rightType: string }>( getRawString('Diagnostic.comparisonAlwaysFalse') ); + export const comparisonAlwaysTrue = () => + new ParameterizedString<{ leftType: string; rightType: string }>( + getRawString('Diagnostic.comparisonAlwaysTrue') + ); export const comprehensionInDict = () => getRawString('Diagnostic.comprehensionInDict'); export const comprehensionInSet = () => getRawString('Diagnostic.comprehensionInSet'); export const concatenateParamSpecMissing = () => getRawString('Diagnostic.concatenateParamSpecMissing'); diff --git a/packages/pyright-internal/src/localization/package.nls.en-us.json b/packages/pyright-internal/src/localization/package.nls.en-us.json index 7217f090b..a8ffa7530 100644 --- a/packages/pyright-internal/src/localization/package.nls.en-us.json +++ b/packages/pyright-internal/src/localization/package.nls.en-us.json @@ -47,6 +47,7 @@ "classVarTooManyArgs": "Expected only one type argument after \"ClassVar\"", "clsSelfParamTypeMismatch": "Type of parameter \"{name}\" must be a supertype of its class \"{classType}\"", "comparisonAlwaysFalse": "Condition will always evaluate to False since the types \"{leftType}\" and \"{rightType}\" have no overlap", + "comparisonAlwaysTrue": "Condition will always evaluate to True since the types \"{leftType}\" and \"{rightType}\" have no overlap", "comprehensionInDict": "Comprehension cannot be used with other dictionary entries", "comprehensionInSet": "Comprehension cannot be used with other set entries", "concatenateParamSpecMissing": "Last type argument for \"Concatenate\" must be a ParamSpec", diff --git a/packages/pyright-internal/src/tests/samples/comparison1.py b/packages/pyright-internal/src/tests/samples/comparison1.py index 92cb66ffa..d5d7a5143 100644 --- a/packages/pyright-internal/src/tests/samples/comparison1.py +++ b/packages/pyright-internal/src/tests/samples/comparison1.py @@ -1,7 +1,7 @@ # This sample tests the check for non-overlapping types compared # with equals comparison. -from typing import Literal +from typing import Literal, TypeVar, Union OS = Literal["Linux", "Darwin", "Windows"] @@ -16,7 +16,7 @@ def func1(os: OS, val: Literal[1, "linux"]): return False # This should generate an error because there is no overlap in types. - if os == val: + if os != val: return False # This should generate an error because there is no overlap in types. @@ -25,3 +25,31 @@ def func1(os: OS, val: Literal[1, "linux"]): if val == 1: return True + +class ClassA: ... +class ClassB: ... + +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2", bound=ClassB) + +def func2(a: ClassA, b: ClassB, c: _T1, d: _T2, e: Union[ClassA, ClassB]) -> Union[None, _T1, _T2]: + # This should generate an error because there is no overlap in types. + if a == b: + return + + # This should generate an error because there is no overlap in types. + if a != b: + return + + if a != c: + return + + # This should generate an error because there is no overlap in types. + if a != d: + return + + if a == e: + return + + if b == e: + return diff --git a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts index 7949bec04..04a5d5ead 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts @@ -1425,8 +1425,14 @@ test('List1', () => { }); test('Comparison1', () => { - const analysisResults = TestUtils.typeAnalyzeSampleFiles(['comparison1.py']); - TestUtils.validateResults(analysisResults, 3); + const configOptions = new ConfigOptions('.'); + + const analysisResults1 = TestUtils.typeAnalyzeSampleFiles(['comparison1.py'], configOptions); + TestUtils.validateResults(analysisResults1, 0); + + configOptions.diagnosticRuleSet.reportUnnecessaryComparison = 'error'; + const analysisResults2 = TestUtils.typeAnalyzeSampleFiles(['comparison1.py'], configOptions); + TestUtils.validateResults(analysisResults2, 6); }); test('EmptyContainers1', () => { diff --git a/packages/vscode-pyright/package.json b/packages/vscode-pyright/package.json index a5764fbd9..00691c792 100644 --- a/packages/vscode-pyright/package.json +++ b/packages/vscode-pyright/package.json @@ -554,6 +554,17 @@ "error" ] }, + "reportUnnecessaryComparison": { + "type": "string", + "description": "Diagnostics for '==' and '!=' comparisons that are statically determined to be unnecessary. Such calls are sometimes indicative of a programming error.", + "default": "none", + "enum": [ + "none", + "information", + "warning", + "error" + ] + }, "reportAssertAlwaysTrue": { "type": "string", "description": "Diagnostics for 'assert' statement that will provably always assert. This can be indicative of a programming error.", diff --git a/packages/vscode-pyright/schemas/pyrightconfig.schema.json b/packages/vscode-pyright/schemas/pyrightconfig.schema.json index 50136b406..6bd385858 100644 --- a/packages/vscode-pyright/schemas/pyrightconfig.schema.json +++ b/packages/vscode-pyright/schemas/pyrightconfig.schema.json @@ -371,6 +371,12 @@ "title": "Controls reporting calls to 'cast' that are unnecessary", "default": "none" }, + "reportUnnecessaryComparison": { + "$id": "#/properties/reportUnnecessaryComparison", + "$ref": "#/definitions/diagnostic", + "title": "Controls reporting the use of '==' or '!=' comparisons that are unnecessary", + "default": "none" + }, "reportAssertAlwaysTrue": { "$id": "#/properties/reportAssertAlwaysTrue", "$ref": "#/definitions/diagnostic",