mirror of
https://github.com/microsoft/pyright.git
synced 2024-10-03 19:37:39 +03:00
Fixed validation of __init__ and __new__.
This commit is contained in:
parent
21eb57f64e
commit
45d3682eeb
@ -53,6 +53,22 @@ export enum EvaluatorFlags {
|
||||
ConvertSpecialTypes = 2
|
||||
}
|
||||
|
||||
export enum MemberAccessFlags {
|
||||
None = 0,
|
||||
|
||||
// By default, both class and instance members are considered.
|
||||
// Set this flag to skip the instance members.
|
||||
SkipInstanceMembers = 1,
|
||||
|
||||
// By default, members of base classes are also searched.
|
||||
// Set this flag to consider only the specified class' members.
|
||||
SkipBaseClasses = 2,
|
||||
|
||||
// By default, if the class has a __getattribute__ or __getattr__
|
||||
// magic method, it is assumed to have any member.
|
||||
SkipGetAttributeCheck = 4
|
||||
}
|
||||
|
||||
interface ParamAssignmentInfo {
|
||||
argsNeeded: number;
|
||||
argsReceived: number;
|
||||
@ -83,7 +99,7 @@ export class ExpressionEvaluator {
|
||||
}
|
||||
|
||||
getTypeFromClassMember(memberName: string, classType: ClassType): Type | undefined {
|
||||
return this._getTypeFromClassMemberString(memberName, classType, false);
|
||||
return this._getTypeFromClassMemberString(memberName, classType, MemberAccessFlags.None);
|
||||
}
|
||||
|
||||
// Determines if the function node is a property accessor (getter, setter, deleter).
|
||||
@ -351,10 +367,12 @@ export class ExpressionEvaluator {
|
||||
isClassMember = true;
|
||||
isObjectMember = true;
|
||||
} else if (baseType instanceof ClassType) {
|
||||
type = this._getTypeFromClassMemberAccess(node.memberName, baseType, false);
|
||||
type = this._getTypeFromClassMemberAccess(node.memberName,
|
||||
baseType, MemberAccessFlags.SkipInstanceMembers);
|
||||
isClassMember = true;
|
||||
} else if (baseType instanceof ObjectType) {
|
||||
type = this._getTypeFromClassMemberAccess(node.memberName, baseType.getClassType(), true);
|
||||
type = this._getTypeFromClassMemberAccess(node.memberName, baseType.getClassType(),
|
||||
MemberAccessFlags.None);
|
||||
isObjectMember = true;
|
||||
} else if (baseType instanceof ModuleType) {
|
||||
let memberInfo = baseType.getFields().get(memberName);
|
||||
@ -422,11 +440,10 @@ export class ExpressionEvaluator {
|
||||
// A wrapper around _getTypeFromClassMemberString that reports
|
||||
// errors if the member name is not found.
|
||||
private _getTypeFromClassMemberAccess(memberNameNode: NameNode,
|
||||
classType: ClassType, includeInstanceMembers: boolean) {
|
||||
classType: ClassType, flags: MemberAccessFlags) {
|
||||
|
||||
const memberName = memberNameNode.nameToken.value;
|
||||
let type = this._getTypeFromClassMemberString(memberName,
|
||||
classType, includeInstanceMembers);
|
||||
let type = this._getTypeFromClassMemberString(memberName, classType, flags);
|
||||
|
||||
if (type) {
|
||||
return type;
|
||||
@ -445,7 +462,7 @@ export class ExpressionEvaluator {
|
||||
}
|
||||
|
||||
private _getTypeFromClassMemberString(memberName: string, classType: ClassType,
|
||||
includeInstanceMembers: boolean): Type | undefined {
|
||||
flags: MemberAccessFlags): Type | undefined {
|
||||
|
||||
// Build a map of type parameters and the type arguments associated with them.
|
||||
let typeArgMap = new TypeVarMap();
|
||||
@ -471,7 +488,9 @@ export class ExpressionEvaluator {
|
||||
typeArgMap.set(typeVarName, typeArgType);
|
||||
});
|
||||
|
||||
let memberInfo = TypeUtils.lookUpClassMember(classType, memberName, includeInstanceMembers);
|
||||
let memberInfo = TypeUtils.lookUpClassMember(classType, memberName,
|
||||
!(flags & MemberAccessFlags.SkipInstanceMembers),
|
||||
!(flags & MemberAccessFlags.SkipBaseClasses));
|
||||
if (memberInfo) {
|
||||
let type = TypeUtils.getEffectiveTypeOfMember(memberInfo);
|
||||
if (type instanceof PropertyType) {
|
||||
@ -481,31 +500,33 @@ export class ExpressionEvaluator {
|
||||
return this._specializeType(type, typeArgMap);
|
||||
}
|
||||
|
||||
// See if the class has a "__getattribute__" or "__getattr__" method.
|
||||
// If so, aribrary members are supported.
|
||||
let getAttribMember = TypeUtils.lookUpClassMember(classType, '__getattribute__');
|
||||
if (getAttribMember && getAttribMember.class) {
|
||||
const isObjectClass = getAttribMember.class.isBuiltIn() &&
|
||||
getAttribMember.class.getClassName() === 'object';
|
||||
// The built-in 'object' class, from which every class derives,
|
||||
// implements the default __getattribute__ method. We want to ignore
|
||||
// this one. If this method is overridden, we need to assume that
|
||||
// all members can be accessed.
|
||||
if (!isObjectClass) {
|
||||
const getAttribType = TypeUtils.getEffectiveTypeOfMember(getAttribMember);
|
||||
if (getAttribType instanceof FunctionType) {
|
||||
return this._specializeType(
|
||||
getAttribType.getEffectiveReturnType(), typeArgMap);
|
||||
if (!(flags & MemberAccessFlags.SkipGetAttributeCheck)) {
|
||||
// See if the class has a "__getattribute__" or "__getattr__" method.
|
||||
// If so, aribrary members are supported.
|
||||
let getAttribMember = TypeUtils.lookUpClassMember(classType, '__getattribute__', false);
|
||||
if (getAttribMember && getAttribMember.class) {
|
||||
const isObjectClass = getAttribMember.class.isBuiltIn() &&
|
||||
getAttribMember.class.getClassName() === 'object';
|
||||
// The built-in 'object' class, from which every class derives,
|
||||
// implements the default __getattribute__ method. We want to ignore
|
||||
// this one. If this method is overridden, we need to assume that
|
||||
// all members can be accessed.
|
||||
if (!isObjectClass) {
|
||||
const getAttribType = TypeUtils.getEffectiveTypeOfMember(getAttribMember);
|
||||
if (getAttribType instanceof FunctionType) {
|
||||
return this._specializeType(
|
||||
getAttribType.getEffectiveReturnType(), typeArgMap);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let getAttrMember = TypeUtils.lookUpClassMember(classType, '__getattr__');
|
||||
if (getAttrMember) {
|
||||
const getAttrType = TypeUtils.getEffectiveTypeOfMember(getAttrMember);
|
||||
if (getAttrType instanceof FunctionType) {
|
||||
return this._specializeType(
|
||||
getAttrType.getEffectiveReturnType(), typeArgMap);
|
||||
let getAttrMember = TypeUtils.lookUpClassMember(classType, '__getattr__', false);
|
||||
if (getAttrMember) {
|
||||
const getAttrType = TypeUtils.getEffectiveTypeOfMember(getAttrMember);
|
||||
if (getAttrType instanceof FunctionType) {
|
||||
return this._specializeType(
|
||||
getAttrType.getEffectiveReturnType(), typeArgMap);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -651,9 +672,8 @@ export class ExpressionEvaluator {
|
||||
|
||||
// Assume this is a call to the constructor.
|
||||
if (!type) {
|
||||
if (this._validateConstructorArguments(node, callType)) {
|
||||
type = new ObjectType(callType);
|
||||
}
|
||||
this._validateConstructorArguments(node, callType);
|
||||
type = new ObjectType(callType);
|
||||
}
|
||||
} else if (callType instanceof FunctionType) {
|
||||
// The stdlib collections/__init__.pyi stub file defines namedtuple
|
||||
@ -690,7 +710,7 @@ export class ExpressionEvaluator {
|
||||
type = UnknownType.create();
|
||||
} else if (callType instanceof ObjectType) {
|
||||
let memberType = this._getTypeFromClassMemberString(
|
||||
'__call__', callType.getClassType(), false);
|
||||
'__call__', callType.getClassType(), MemberAccessFlags.SkipGetAttributeCheck);
|
||||
if (memberType && memberType instanceof FunctionType) {
|
||||
if (this._validateCallArguments(node, memberType, true)) {
|
||||
type = memberType.getEffectiveReturnType();
|
||||
@ -721,6 +741,9 @@ export class ExpressionEvaluator {
|
||||
}
|
||||
} else if (callType.isAny()) {
|
||||
type = UnknownType.create();
|
||||
} else if (callType instanceof TypeVarType) {
|
||||
// TODO - remove once we support generics
|
||||
type = UnknownType.create();
|
||||
}
|
||||
|
||||
if (type === undefined) {
|
||||
@ -753,31 +776,29 @@ export class ExpressionEvaluator {
|
||||
}
|
||||
|
||||
// Tries to match the arguments of a call to the constructor for a class.
|
||||
private _validateConstructorArguments(node: CallExpressionNode, type: ClassType): boolean {
|
||||
let isValid = false;
|
||||
private _validateConstructorArguments(node: CallExpressionNode, type: ClassType): void {
|
||||
let validatedTypes = false;
|
||||
|
||||
let initMethodType = this._getTypeFromClassMemberString('__init__', type, false);
|
||||
let initMethodType = this._getTypeFromClassMemberString('__init__', type,
|
||||
MemberAccessFlags.SkipGetAttributeCheck | MemberAccessFlags.SkipInstanceMembers);
|
||||
if (initMethodType) {
|
||||
isValid = this._validateCallArguments(node, initMethodType, true);
|
||||
this._validateCallArguments(node, initMethodType, true);
|
||||
validatedTypes = true;
|
||||
}
|
||||
|
||||
if (!validatedTypes) {
|
||||
// If there's no init method, check for a constructor.
|
||||
let constructorMethodType = this._getTypeFromClassMemberString('__new__', type, false);
|
||||
if (constructorMethodType) {
|
||||
isValid = this._validateCallArguments(node, constructorMethodType, true);
|
||||
validatedTypes = true;
|
||||
}
|
||||
// If there's no init method, check for a constructor.
|
||||
let constructorMethodType = this._getTypeFromClassMemberString('__new__', type,
|
||||
MemberAccessFlags.SkipGetAttributeCheck | MemberAccessFlags.SkipInstanceMembers |
|
||||
MemberAccessFlags.SkipBaseClasses);
|
||||
if (constructorMethodType) {
|
||||
this._validateCallArguments(node, constructorMethodType, true);
|
||||
validatedTypes = true;
|
||||
}
|
||||
|
||||
if (!validatedTypes && node.arguments.length > 0) {
|
||||
this._addError(
|
||||
`Expected no arguments to '${ type.getClassName() }' constructor`, node);
|
||||
}
|
||||
|
||||
return isValid;
|
||||
}
|
||||
|
||||
private _validateCallArguments(node: CallExpressionNode, callType: Type,
|
||||
@ -802,7 +823,8 @@ export class ExpressionEvaluator {
|
||||
} else if (callType instanceof ObjectType) {
|
||||
isCallable = false;
|
||||
let memberType = this._getTypeFromClassMemberString(
|
||||
'__call__', callType.getClassType(), false);
|
||||
'__call__', callType.getClassType(),
|
||||
MemberAccessFlags.SkipGetAttributeCheck | MemberAccessFlags.SkipInstanceMembers);
|
||||
|
||||
if (memberType && memberType instanceof FunctionType) {
|
||||
isCallable = this._validateCallArguments(node, memberType, true);
|
||||
|
@ -426,6 +426,9 @@ export class TypeAnalyzer extends ParseTreeWalker {
|
||||
let constExprValue = ExpressionUtils.evaluateConstantExpression(
|
||||
node.testExpression, this._fileInfo.executionEnvironment);
|
||||
|
||||
// Get and cache the expression type before walking it. This will apply
|
||||
// any type constraints along the way.
|
||||
this._getTypeOfExpression(node.testExpression);
|
||||
this.walk(node.testExpression);
|
||||
|
||||
let typeConstraints = this._buildTypeConstraints(node.testExpression);
|
||||
|
@ -236,7 +236,7 @@ export class TypeUtils {
|
||||
// defined by Python. For more detials, see this note on method resolution
|
||||
// order: https://www.python.org/download/releases/2.3/mro/.
|
||||
static lookUpClassMember(classType: Type, memberName: string,
|
||||
includeInstanceFields = true): ClassMember | undefined {
|
||||
includeInstanceFields = true, searchBaseClasses = true): ClassMember | undefined {
|
||||
|
||||
if (classType instanceof ClassType) {
|
||||
// TODO - for now, use naive depth-first search.
|
||||
@ -269,10 +269,13 @@ export class TypeUtils {
|
||||
};
|
||||
}
|
||||
|
||||
for (let baseClass of classType.getBaseClasses()) {
|
||||
let methodType = this.lookUpClassMember(baseClass.type, memberName);
|
||||
if (methodType) {
|
||||
return methodType;
|
||||
if (searchBaseClasses) {
|
||||
for (let baseClass of classType.getBaseClasses()) {
|
||||
let methodType = this.lookUpClassMember(baseClass.type,
|
||||
memberName, searchBaseClasses);
|
||||
if (methodType) {
|
||||
return methodType;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (classType.isAny()) {
|
||||
|
Loading…
Reference in New Issue
Block a user