diff --git a/packages/pyright-internal/src/analyzer/checker.ts b/packages/pyright-internal/src/analyzer/checker.ts index 037145622..61e203bb2 100644 --- a/packages/pyright-internal/src/analyzer/checker.ts +++ b/packages/pyright-internal/src/analyzer/checker.ts @@ -95,6 +95,7 @@ import { OperatorType, StringTokenFlags, TokenType } from '../parser/tokenizerTy import { AnalyzerFileInfo } from './analyzerFileInfo'; import * as AnalyzerNodeInfo from './analyzerNodeInfo'; import { getBoundCallMethod, getBoundInitMethod, getBoundNewMethod } from './constructors'; +import { addInheritedDataClassEntries } from './dataClasses'; import { Declaration, DeclarationType, isAliasDeclaration } from './declaration'; import { getNameNodeForDeclaration } from './declarationUtils'; import { deprecatedAliases, deprecatedSpecialForms } from './deprecatedSymbols'; @@ -165,6 +166,7 @@ import { AnyType, ClassType, ClassTypeFlags, + DataClassEntry, EnumLiteral, FunctionType, FunctionTypeFlags, @@ -5294,6 +5296,13 @@ export class Checker extends ParseTreeWalker { getProtocolSymbolsRecursive(classType, abstractSymbols, ClassTypeFlags.SupportsAbstractMethods); } + // If this is a dataclass, get all of the entries so we can tell which + // ones are initialized by the synthesized __init__ method. + const dataClassEntries: DataClassEntry[] = []; + if (ClassType.isDataClass(classType)) { + addInheritedDataClassEntries(classType, dataClassEntries); + } + ClassType.getSymbolTable(classType).forEach((localSymbol, name) => { abstractSymbols.delete(name); @@ -5376,20 +5385,30 @@ export class Checker extends ParseTreeWalker { return; } - if (decls[0].type === DeclarationType.Variable) { - // If none of the declarations involve assignments, assume it's - // not implemented in the protocol. - if (!decls.some((decl) => decl.type === DeclarationType.Variable && !!decl.inferredTypeSource)) { - // This is a variable declaration that is not implemented in the - // protocol base class. Make sure it's implemented in the derived class. - diagAddendum.addMessage( - LocAddendum.uninitializedAbstractVariable().format({ - name, - classType: member.classType.details.name, - }) - ); + if (decls[0].type !== DeclarationType.Variable) { + return; + } + + // Dataclass fields are typically exempted from this check because + // they have synthesized __init__ methods that initialize these variables. + const dcEntry = dataClassEntries?.find((entry) => entry.name === name); + if (dcEntry) { + if (dcEntry.includeInInit) { + return; + } + } else { + // Do one or more declarations involve assignments? + if (decls.some((decl) => decl.type === DeclarationType.Variable && !!decl.inferredTypeSource)) { + return; } } + + diagAddendum.addMessage( + LocAddendum.uninitializedAbstractVariable().format({ + name, + classType: member.classType.details.name, + }) + ); }); if (!diagAddendum.isEmpty()) { diff --git a/packages/pyright-internal/src/analyzer/dataClasses.ts b/packages/pyright-internal/src/analyzer/dataClasses.ts index 4d7c45994..8494e9087 100644 --- a/packages/pyright-internal/src/analyzer/dataClasses.ts +++ b/packages/pyright-internal/src/analyzer/dataClasses.ts @@ -995,7 +995,7 @@ function transformDescriptorType(evaluator: TypeEvaluator, type: Type): Type { // the specified class. These entries must be unique and in reverse-MRO // order. Returns true if all of the class types in the hierarchy are // known, false if one or more are unknown. -function addInheritedDataClassEntries(classType: ClassType, entries: DataClassEntry[]) { +export function addInheritedDataClassEntries(classType: ClassType, entries: DataClassEntry[]) { let allAncestorsAreKnown = true; ClassType.getReverseMro(classType).forEach((mroClass) => { diff --git a/packages/pyright-internal/src/tests/checker.test.ts b/packages/pyright-internal/src/tests/checker.test.ts index 5a359a640..2ec0da7e4 100644 --- a/packages/pyright-internal/src/tests/checker.test.ts +++ b/packages/pyright-internal/src/tests/checker.test.ts @@ -493,7 +493,7 @@ test('UninitializedVariable2', () => { // Enable it as an error. configOptions.diagnosticRuleSet.reportUninitializedInstanceVariable = 'error'; analysisResults = TestUtils.typeAnalyzeSampleFiles(['uninitializedVariable2.py'], configOptions); - TestUtils.validateResults(analysisResults, 2); + TestUtils.validateResults(analysisResults, 3); }); test('Deprecated1', () => { diff --git a/packages/pyright-internal/src/tests/samples/uninitializedVariable2.py b/packages/pyright-internal/src/tests/samples/uninitializedVariable2.py index 12fc05f52..a2214e959 100644 --- a/packages/pyright-internal/src/tests/samples/uninitializedVariable2.py +++ b/packages/pyright-internal/src/tests/samples/uninitializedVariable2.py @@ -2,7 +2,8 @@ # to a concrete implementation of an abstract base class that defines # (but does not assign) variables. -from abc import ABC +from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import NamedTuple, final @@ -52,3 +53,16 @@ class G(Abstract3): class H(NamedTuple): x: int + + +@dataclass +class IAbstract(ABC): + p1: str + p2: int = field(init=False) + + +@final +@dataclass +# This should generate an error because p2 is uninitialized. +class I(IAbstract): + p3: int