From a801d787b9e8c379c4bb8eae72d3f491f3cf1e75 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Fri, 12 Nov 2021 00:30:57 -0800 Subject: [PATCH] Fixed regression in handling of callback protocols that define a `__name__` attribute, which is common to all functions. --- .../src/analyzer/typeEvaluator.ts | 15 +++++++- .../src/tests/samples/callbackProtocol5.py | 37 ++++++++++++++----- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 824d50997..9bafc8f5a 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -480,6 +480,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions let noneType: Type | undefined; let objectType: Type | undefined; let typeClassType: Type | undefined; + let functionObj: Type | undefined; let tupleClassType: Type | undefined; let boolClassType: Type | undefined; let strClassType: Type | undefined; @@ -688,6 +689,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions objectType = getBuiltInObject(node, 'object'); typeClassType = getBuiltInType(node, 'type'); + functionObj = getBuiltInObject(node, 'function'); // Initialize and cache "Collection" to break a cyclical dependency // that occurs when resolving tuple below. @@ -4125,7 +4127,6 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions const functionType = isFunction(baseType) ? baseType : baseType.overloads[0]; type = functionType.boundToType; } else { - const functionObj = getBuiltInObject(node, 'function'); if (!functionObj) { type = AnyType.create(); } else { @@ -19537,7 +19538,17 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions if (isClass(mroClass) && ClassType.isProtocolClass(mroClass)) { for (const field of mroClass.details.fields) { if (field[0] !== '__call__' && !field[1].isIgnoredForProtocolMatch()) { - return undefined; + let fieldIsPartOfFunction = false; + + if (functionObj && isClass(functionObj)) { + if (functionObj.details.fields.has(field[0])) { + fieldIsPartOfFunction = true; + } + } + + if (!fieldIsPartOfFunction) { + return undefined; + } } } } diff --git a/packages/pyright-internal/src/tests/samples/callbackProtocol5.py b/packages/pyright-internal/src/tests/samples/callbackProtocol5.py index 3f6a02682..d43a360bf 100644 --- a/packages/pyright-internal/src/tests/samples/callbackProtocol5.py +++ b/packages/pyright-internal/src/tests/samples/callbackProtocol5.py @@ -1,7 +1,7 @@ # This sample tests the case where a callback protocol defines additional # attributes. -from typing import Callable, Literal, Protocol, TypeVar, cast +from typing import Any, Callable, Literal, Protocol, TypeVar, cast from typing_extensions import ParamSpec @@ -9,7 +9,7 @@ P = ParamSpec("P") R = TypeVar("R", covariant=True) -class SomeFunc(Protocol[P, R]): +class SomeFunc1(Protocol[P, R]): __name__: str other_attribute: int @@ -18,8 +18,8 @@ class SomeFunc(Protocol[P, R]): ... -def other_func(f: Callable[P, R]) -> SomeFunc[P, R]: - converted = cast(SomeFunc, f) +def other_func1(f: Callable[P, R]) -> SomeFunc1[P, R]: + converted = cast(SomeFunc1, f) print(converted.__name__) @@ -34,16 +34,33 @@ def other_func(f: Callable[P, R]) -> SomeFunc[P, R]: return converted -@other_func -def some_func(x: int) -> str: +@other_func1 +def some_func1(x: int) -> str: ... -t1: Literal["SomeFunc[(x: int), str]"] = reveal_type(some_func) +t1: Literal["SomeFunc1[(x: int), str]"] = reveal_type(some_func1) -some_func.other_attribute +some_func1.other_attribute # This should generate an error -some_func.other_attribute2 +some_func1.other_attribute2 -some_func(x=3) +some_func1(x=3) + + +class SomeFunc2(Protocol): + __name__: str + __module__: str + __qualname__: str + __annotations__: dict[str, Any] + + def __call__(self) -> None: + ... + + +def some_func2() -> None: + ... + + +v: SomeFunc2 = some_func2