mirror of
https://github.com/microsoft/pyright.git
synced 2024-10-26 10:55:06 +03:00
Improved special-case code for functools.totalordering so it enforces the operand type in the provided __lt__
method.
This commit is contained in:
parent
a83fef3c6f
commit
e3b6e7b51c
@ -22,8 +22,9 @@ import {
|
||||
isFunction,
|
||||
isInstantiableClass,
|
||||
OverloadedFunctionType,
|
||||
Type,
|
||||
} from './types';
|
||||
import { ClassMemberLookupFlags, lookUpObjectMember, synthesizeTypeVarForSelfCls } from './typeUtils';
|
||||
import { ClassMember, ClassMemberLookupFlags, lookUpObjectMember, synthesizeTypeVarForSelfCls } from './typeUtils';
|
||||
|
||||
export function applyFunctionTransform(
|
||||
evaluator: TypeEvaluator,
|
||||
@ -62,11 +63,16 @@ function applyTotalOrderingTransform(
|
||||
const instanceType = ClassType.cloneAsInstance(classType);
|
||||
|
||||
// Verify that the class has at least one of the required functions.
|
||||
let firstMemberFound: ClassMember | undefined;
|
||||
const missingMethods = orderingMethods.filter((methodName) => {
|
||||
return !lookUpObjectMember(instanceType, methodName, ClassMemberLookupFlags.SkipInstanceVariables);
|
||||
const memberInfo = lookUpObjectMember(instanceType, methodName, ClassMemberLookupFlags.SkipInstanceVariables);
|
||||
if (memberInfo && !firstMemberFound) {
|
||||
firstMemberFound = memberInfo;
|
||||
}
|
||||
return !memberInfo;
|
||||
});
|
||||
|
||||
if (missingMethods.length === orderingMethods.length) {
|
||||
if (!firstMemberFound) {
|
||||
evaluator.addDiagnostic(
|
||||
getFileInfo(errorNode).diagnosticRuleSet.reportGeneralTypeIssues,
|
||||
DiagnosticRule.reportGeneralTypeIssues,
|
||||
@ -76,9 +82,26 @@ function applyTotalOrderingTransform(
|
||||
return result;
|
||||
}
|
||||
|
||||
const objectType = evaluator.getBuiltInObject(errorNode, 'object');
|
||||
if (!objectType || !isClassInstance(objectType)) {
|
||||
return result;
|
||||
// Determine what type to use for the parameter corresponding to
|
||||
// the second operand. This will be taken from the existing method.
|
||||
let operandType: Type | undefined;
|
||||
|
||||
const firstMemberType = evaluator.getTypeOfMember(firstMemberFound);
|
||||
if (
|
||||
isFunction(firstMemberType) &&
|
||||
firstMemberType.details.parameters.length >= 2 &&
|
||||
firstMemberType.details.parameters[1].hasDeclaredType
|
||||
) {
|
||||
operandType = firstMemberType.details.parameters[1].type;
|
||||
}
|
||||
|
||||
// If there was no provided operand type, fall back to object.
|
||||
if (!operandType) {
|
||||
const objectType = evaluator.getBuiltInObject(errorNode, 'object');
|
||||
if (!objectType || !isClassInstance(objectType)) {
|
||||
return result;
|
||||
}
|
||||
operandType = objectType;
|
||||
}
|
||||
|
||||
const boolType = evaluator.getBuiltInObject(errorNode, 'bool');
|
||||
@ -96,7 +119,7 @@ function applyTotalOrderingTransform(
|
||||
const objParam: FunctionParameter = {
|
||||
category: ParameterCategory.Simple,
|
||||
name: '__value',
|
||||
type: objectType,
|
||||
type: operandType,
|
||||
hasDeclaredType: true,
|
||||
};
|
||||
|
||||
|
@ -25,3 +25,28 @@ v6 = a != b
|
||||
@total_ordering
|
||||
class ClassB:
|
||||
val1: int
|
||||
|
||||
|
||||
@total_ordering
|
||||
class ClassC:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return False
|
||||
|
||||
def __lt__(self, other: "ClassC") -> bool:
|
||||
return False
|
||||
|
||||
|
||||
reveal_type(ClassC() < ClassC(), expected_text="bool")
|
||||
reveal_type(ClassC() <= ClassC(), expected_text="bool")
|
||||
reveal_type(ClassC() == ClassC(), expected_text="bool")
|
||||
reveal_type(ClassC() > ClassC(), expected_text="bool")
|
||||
reveal_type(ClassC() >= ClassC(), expected_text="bool")
|
||||
|
||||
_ = ClassC() == 1
|
||||
_ = ClassC() != 1
|
||||
|
||||
# The following four lines should each produce an error.
|
||||
_ = ClassC() < 1
|
||||
_ = ClassC() <= 1
|
||||
_ = ClassC() > 1
|
||||
_ = ClassC() >= 1
|
||||
|
@ -1243,7 +1243,7 @@ test('Partial2', () => {
|
||||
test('TotalOrdering1', () => {
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['totalOrdering1.py']);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 1);
|
||||
TestUtils.validateResults(analysisResults, 5);
|
||||
});
|
||||
|
||||
test('TupleUnpack1', () => {
|
||||
|
Loading…
Reference in New Issue
Block a user