From 32076952c7d9eeaf7363f009dbcbad09b3e9cef6 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Thu, 6 Jun 2024 10:29:53 -0700 Subject: [PATCH] Fixed bug that results in a false positive when accessing a generic attribute in a base class from a subclass that explicitly specializes the generic type. This addresses #8083. (#8086) --- .../src/analyzer/typeEvaluator.ts | 37 +++++---- .../src/analyzer/typeUtils.ts | 27 ++++++- .../pyright-internal/src/analyzer/types.ts | 4 +- .../src/tests/samples/memberAccess1.py | 30 -------- .../src/tests/samples/memberAccess25.py | 75 +++++++++++++++++++ .../src/tests/typeEvaluator4.test.ts | 7 +- 6 files changed, 127 insertions(+), 53 deletions(-) create mode 100644 packages/pyright-internal/src/tests/samples/memberAccess25.py diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 8ef9b78ff..383d2e405 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -22392,26 +22392,31 @@ export function createTypeEvaluator( // prior to specializing. inferReturnTypeIfNecessary(typeResult.type); + // Check for ambiguous accesses to attributes with generic types? if ( + errorNode && + selfClass && + isClass(selfClass) && member.isInstanceMember && - (flags & MemberAccessFlags.DisallowGenericInstanceVariableAccess) !== 0 + isClass(member.unspecializedClassType) && + (flags & MemberAccessFlags.DisallowGenericInstanceVariableAccess) !== 0 && + requiresSpecialization(typeResult.type, { ignoreSelf: true, ignoreImplicitTypeArgs: true }) ) { - let isGenericNonCallable = false; + const specializedType = partiallySpecializeType( + typeResult.type, + member.unspecializedClassType, + selfSpecializeClass(selfClass, /* overrideTypeArgs */ true) + ); - doForEachSubtype(typeResult.type, (subtype) => { - if (!isAnyOrUnknown(subtype) && !isFunction(subtype) && !isOverloadedFunction(subtype)) { - if ( - requiresSpecialization(typeResult.type, { - ignoreSelf: true, - ignoreImplicitTypeArgs: true, - }) - ) { - isGenericNonCallable = true; - } - } - }); - - if (isGenericNonCallable && errorNode) { + if ( + findSubtype( + specializedType, + (subtype) => + !isFunction(subtype) && + !isOverloadedFunction(subtype) && + requiresSpecialization(subtype, { ignoreSelf: true, ignoreImplicitTypeArgs: true }) + ) + ) { addDiagnostic( DiagnosticRule.reportGeneralTypeIssues, LocMessage.genericInstanceVariableAccess(), diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index 085885d11..a166223c4 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -73,6 +73,9 @@ export interface ClassMember { // Partially-specialized class that contains the class member classType: ClassType | UnknownType | AnyType; + // Unspecialized class that contains the class member + unspecializedClassType: ClassType | UnknownType | AnyType; + // True if it is an instance or class member; it can be both a class and // an instance member in cases where a class variable is overridden // by an instance variable @@ -1141,8 +1144,12 @@ export function getUnknownTypeForCallable(): FunctionType { // If the class is generic and not already specialized, this function // "self specializes" the class, filling in its own type parameters // as type arguments. -export function selfSpecializeClass(type: ClassType): ClassType { - if (type.details.typeParameters.length === 0 || type.typeArguments) { +export function selfSpecializeClass(type: ClassType, overrideTypeArgs = false): ClassType { + if (type.details.typeParameters.length === 0) { + return type; + } + + if (type.typeArguments && !overrideTypeArgs) { return type; } @@ -1640,6 +1647,7 @@ export function getProtocolSymbolsRecursive( symbolMap.set(name, { symbol, classType, + unspecializedClassType: classType, isInstanceMember: symbol.isInstanceMember(), isClassMember: symbol.isClassMember(), isClassVar: isEffectivelyClassVar(symbol, /* isDataclass */ false), @@ -1772,6 +1780,8 @@ export function* getClassMemberIterator( for (const [mroClass, specializedMroClass] of classItr) { if (!isInstantiableClass(mroClass)) { if (!declaredTypesOnly) { + const classType = isAnyOrUnknown(mroClass) ? mroClass : UnknownType.create(); + // The class derives from an unknown type, so all bets are off // when trying to find a member. Return an unknown symbol. const cm: ClassMember = { @@ -1779,7 +1789,8 @@ export function* getClassMemberIterator( isInstanceMember: false, isClassMember: true, isClassVar: false, - classType: isAnyOrUnknown(mroClass) ? mroClass : UnknownType.create(), + classType, + unspecializedClassType: classType, isTypeDeclared: false, skippedUndeclaredType: false, }; @@ -1806,6 +1817,7 @@ export function* getClassMemberIterator( isClassMember: symbol.isClassMember(), isClassVar: isEffectivelyClassVar(symbol, ClassType.isDataClass(specializedMroClass)), classType: specializedMroClass, + unspecializedClassType: mroClass, isTypeDeclared: hasDeclaredType, skippedUndeclaredType, }; @@ -1846,6 +1858,7 @@ export function* getClassMemberIterator( isClassMember, isClassVar: isEffectivelyClassVar(symbol, isDataclass), classType: specializedMroClass, + unspecializedClassType: mroClass, isTypeDeclared: hasDeclaredType, skippedUndeclaredType, }; @@ -1865,6 +1878,7 @@ export function* getClassMemberIterator( isClassMember: true, isClassVar: false, classType, + unspecializedClassType: classType, isTypeDeclared: false, skippedUndeclaredType: false, }; @@ -1936,6 +1950,7 @@ export function getClassFieldsRecursive(classType: ClassType): Map None: - ... - - -ClassG[int].y = 1 -ClassG[int].y -del ClassG[int].y - -ClassG.y = 1 -ClassG.y -del ClassG.y - -# This should generate an error because x is generic. -ClassG[int].x = 1 - -# This should generate an error because x is generic. -ClassG[int].x - -# This should generate an error because x is generic. -del ClassG[int].x - -# This should generate an error because x is generic. -ClassG.x - -# This should generate an error because x is generic. -del ClassG.x diff --git a/packages/pyright-internal/src/tests/samples/memberAccess25.py b/packages/pyright-internal/src/tests/samples/memberAccess25.py new file mode 100644 index 000000000..411723d7c --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/memberAccess25.py @@ -0,0 +1,75 @@ +# This sample tests for the check of a member access through a generic class +# when the type of the attribute is generic (and therefore its type is +# ambiguous). + +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class ClassA(Generic[T]): + x: T + y: int + + def __init__(self, label: T | None = None) -> None: ... + + +ClassA[int].y = 1 +ClassA[int].y +del ClassA[int].y + +ClassA.y = 1 +ClassA.y +del ClassA.y + +# This should generate an error because x is generic. +ClassA[int].x = 1 + +# This should generate an error because x is generic. +ClassA[int].x + +# This should generate an error because x is generic. +del ClassA[int].x + +# This should generate an error because x is generic. +ClassA.x = 1 + +# This should generate an error because x is generic. +ClassA.x + +# This should generate an error because x is generic. +del ClassA.x + + +class ClassB(ClassA[T]): + pass + + +# This should generate an error because x is generic. +ClassB[int].x = 1 + +# This should generate an error because x is generic. +ClassB[int].x + +# This should generate an error because x is generic. +del ClassB[int].x + +# This should generate an error because x is generic. +ClassB.x = 1 + +# This should generate an error because x is generic. +ClassB.x + +# This should generate an error because x is generic. +del ClassB.x + + +class ClassC(ClassA[int]): + pass + + +ClassC.x = 1 +ClassC.x +del ClassC.x +ClassC.x +del ClassC.x diff --git a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts index a99994d62..654292f10 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts @@ -118,7 +118,7 @@ test('FString5', () => { test('MemberAccess1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['memberAccess1.py']); - TestUtils.validateResults(analysisResults, 5); + TestUtils.validateResults(analysisResults, 0); }); test('MemberAccess2', () => { @@ -236,6 +236,11 @@ test('MemberAccess24', () => { TestUtils.validateResults(analysisResults, 0); }); +test('MemberAccess25', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['memberAccess25.py']); + TestUtils.validateResults(analysisResults, 12); +}); + test('DataClassNamedTuple1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dataclassNamedTuple1.py']);