diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index ab0e17245..9c549cd0e 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -20050,6 +20050,12 @@ export function createTypeEvaluator( addError(LocMessage.protocolNotAllowed(), errorNode); } + typeArgs?.forEach((typeArg) => { + if (typeArg.typeList || !isTypeVar(typeArg.type)) { + addError(LocMessage.protocolTypeArgMustBeTypeParam(), typeArg.node); + } + }); + return { type: createSpecialType( classType, diff --git a/packages/pyright-internal/src/localization/localize.ts b/packages/pyright-internal/src/localization/localize.ts index c0d0d9876..06c94d41f 100644 --- a/packages/pyright-internal/src/localization/localize.ts +++ b/packages/pyright-internal/src/localization/localize.ts @@ -811,6 +811,7 @@ export namespace Localizer { export const protocolBaseClassWithTypeArgs = () => getRawString('Diagnostic.protocolBaseClassWithTypeArgs'); export const protocolIllegal = () => getRawString('Diagnostic.protocolIllegal'); export const protocolNotAllowed = () => getRawString('Diagnostic.protocolNotAllowed'); + export const protocolTypeArgMustBeTypeParam = () => getRawString('Diagnostic.protocolTypeArgMustBeTypeParam'); export const protocolUnsafeOverlap = () => new ParameterizedString<{ name: string }>(getRawString('Diagnostic.protocolUnsafeOverlap')); export const protocolVarianceContravariant = () => diff --git a/packages/pyright-internal/src/localization/package.nls.en-us.json b/packages/pyright-internal/src/localization/package.nls.en-us.json index 4bbd301f1..f12de1f81 100644 --- a/packages/pyright-internal/src/localization/package.nls.en-us.json +++ b/packages/pyright-internal/src/localization/package.nls.en-us.json @@ -389,6 +389,7 @@ "protocolBaseClassWithTypeArgs": "Type arguments are not allowed with Protocol class when using type parameter syntax", "protocolIllegal": "Use of \"Protocol\" requires Python 3.7 or newer", "protocolNotAllowed": "\"Protocol\" cannot be used in this context", + "protocolTypeArgMustBeTypeParam": "Type argument for \"Protocol\" must be a type parameter", "protocolUnsafeOverlap": "Class overlaps \"{name}\" unsafely and could produce a match at runtime", "protocolVarianceContravariant": "Type variable \"{variable}\" used in generic protocol \"{class}\" should be contravariant", "protocolVarianceCovariant": "Type variable \"{variable}\" used in generic protocol \"{class}\" should be covariant", diff --git a/packages/pyright-internal/src/tests/samples/protocol1.py b/packages/pyright-internal/src/tests/samples/protocol1.py index 5d190da5e..ccf66cada 100644 --- a/packages/pyright-internal/src/tests/samples/protocol1.py +++ b/packages/pyright-internal/src/tests/samples/protocol1.py @@ -1,6 +1,6 @@ # This sample tests the type checker's handling of generic protocol types. -from typing import Generic, TypeVar, Protocol +from typing import Generic, Protocol, TypeVar T = TypeVar("T") T_co = TypeVar("T_co", covariant=True) @@ -8,13 +8,11 @@ T_contra = TypeVar("T_contra", contravariant=True) class Box(Protocol[T_co]): - def content(self) -> T_co: - ... + def content(self) -> T_co: ... class Box_Impl: - def content(self) -> int: - ... + def content(self) -> int: ... box: Box[float] @@ -25,13 +23,11 @@ box = second_box class Sender(Protocol[T_contra]): - def send(self, data: T_contra) -> int: - ... + def send(self, data: T_contra) -> int: ... class Sender_Impl: - def send(self, data: float) -> int: - ... + def send(self, data: float) -> int: ... sender: Sender[float] = Sender_Impl() @@ -77,8 +73,7 @@ var2: list[Protocol] = [] class Abstract1(Protocol[T_contra]): - def do(self, x: T_contra | None): - ... + def do(self, x: T_contra | None): ... class Concrete1: @@ -95,34 +90,28 @@ use_protocol1(Concrete1()) # This should generate an error because TypeVars cannot # be defined in both Protocol and Generic. -class Proto2(Protocol[T_co], Generic[T_co]): - ... +class Proto2(Protocol[T_co], Generic[T_co]): ... -class Proto3(Protocol, Generic[T_co]): - ... +class Proto3(Protocol, Generic[T_co]): ... _A = TypeVar("_A", covariant=True) _B = TypeVar("_B", covariant=True, bound=int) -class ProtoBase1(Protocol[_A, _B]): - ... +class ProtoBase1(Protocol[_A, _B]): ... # This should generate an error because Protocol must # include all of the TypeVars. -class Proto4(ProtoBase1[_A, _B], Protocol[_A]): - ... +class Proto4(ProtoBase1[_A, _B], Protocol[_A]): ... -class ProtoBase2(Protocol[_B]): - ... +class ProtoBase2(Protocol[_B]): ... -class Proto5(ProtoBase2[_B], Protocol[_A, _B]): - ... +class Proto5(ProtoBase2[_B], Protocol[_A, _B]): ... p5_1: Proto5[float, int] @@ -141,3 +130,8 @@ def func1(): # This should generate an error because Protocol isn't # allowed in a TypeVar bound. T = TypeVar("T", bound=Protocol | int) + + +# This should generate an error because int is not a TypeVar +class Proto6(Protocol[int]): + pass diff --git a/packages/pyright-internal/src/tests/typeEvaluator7.test.ts b/packages/pyright-internal/src/tests/typeEvaluator7.test.ts index fdb201919..056c437fa 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator7.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator7.test.ts @@ -286,7 +286,7 @@ test('GenericType45', () => { test('Protocol1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['protocol1.py']); - TestUtils.validateResults(analysisResults, 8); + TestUtils.validateResults(analysisResults, 9); }); test('Protocol2', () => {