Improved special-case code for functools.totalordering so it enforces the operand type in the provided __lt__ method.

This commit is contained in:
Eric Traut 2022-05-18 09:12:20 -07:00
parent a83fef3c6f
commit e3b6e7b51c
3 changed files with 56 additions and 8 deletions

View File

@ -22,8 +22,9 @@ import {
isFunction,
isInstantiableClass,
OverloadedFunctionType,
Type,
} from './types';
import { ClassMemberLookupFlags, lookUpObjectMember, synthesizeTypeVarForSelfCls } from './typeUtils';
import { ClassMember, ClassMemberLookupFlags, lookUpObjectMember, synthesizeTypeVarForSelfCls } from './typeUtils';
export function applyFunctionTransform(
evaluator: TypeEvaluator,
@ -62,11 +63,16 @@ function applyTotalOrderingTransform(
const instanceType = ClassType.cloneAsInstance(classType);
// Verify that the class has at least one of the required functions.
let firstMemberFound: ClassMember | undefined;
const missingMethods = orderingMethods.filter((methodName) => {
return !lookUpObjectMember(instanceType, methodName, ClassMemberLookupFlags.SkipInstanceVariables);
const memberInfo = lookUpObjectMember(instanceType, methodName, ClassMemberLookupFlags.SkipInstanceVariables);
if (memberInfo && !firstMemberFound) {
firstMemberFound = memberInfo;
}
return !memberInfo;
});
if (missingMethods.length === orderingMethods.length) {
if (!firstMemberFound) {
evaluator.addDiagnostic(
getFileInfo(errorNode).diagnosticRuleSet.reportGeneralTypeIssues,
DiagnosticRule.reportGeneralTypeIssues,
@ -76,9 +82,26 @@ function applyTotalOrderingTransform(
return result;
}
const objectType = evaluator.getBuiltInObject(errorNode, 'object');
if (!objectType || !isClassInstance(objectType)) {
return result;
// Determine what type to use for the parameter corresponding to
// the second operand. This will be taken from the existing method.
let operandType: Type | undefined;
const firstMemberType = evaluator.getTypeOfMember(firstMemberFound);
if (
isFunction(firstMemberType) &&
firstMemberType.details.parameters.length >= 2 &&
firstMemberType.details.parameters[1].hasDeclaredType
) {
operandType = firstMemberType.details.parameters[1].type;
}
// If there was no provided operand type, fall back to object.
if (!operandType) {
const objectType = evaluator.getBuiltInObject(errorNode, 'object');
if (!objectType || !isClassInstance(objectType)) {
return result;
}
operandType = objectType;
}
const boolType = evaluator.getBuiltInObject(errorNode, 'bool');
@ -96,7 +119,7 @@ function applyTotalOrderingTransform(
const objParam: FunctionParameter = {
category: ParameterCategory.Simple,
name: '__value',
type: objectType,
type: operandType,
hasDeclaredType: true,
};

View File

@ -25,3 +25,28 @@ v6 = a != b
@total_ordering
class ClassB:
val1: int
@total_ordering
class ClassC:
def __eq__(self, other: object) -> bool:
return False
def __lt__(self, other: "ClassC") -> bool:
return False
reveal_type(ClassC() < ClassC(), expected_text="bool")
reveal_type(ClassC() <= ClassC(), expected_text="bool")
reveal_type(ClassC() == ClassC(), expected_text="bool")
reveal_type(ClassC() > ClassC(), expected_text="bool")
reveal_type(ClassC() >= ClassC(), expected_text="bool")
_ = ClassC() == 1
_ = ClassC() != 1
# The following four lines should each produce an error.
_ = ClassC() < 1
_ = ClassC() <= 1
_ = ClassC() > 1
_ = ClassC() >= 1

View File

@ -1243,7 +1243,7 @@ test('Partial2', () => {
test('TotalOrdering1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['totalOrdering1.py']);
TestUtils.validateResults(analysisResults, 1);
TestUtils.validateResults(analysisResults, 5);
});
test('TupleUnpack1', () => {