Fixed regression in handling of callback protocols that define a __name__ attribute, which is common to all functions.

This commit is contained in:
Eric Traut 2021-11-12 00:30:57 -08:00
parent 52a241cf42
commit a801d787b9
2 changed files with 40 additions and 12 deletions

View File

@ -480,6 +480,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
let noneType: Type | undefined;
let objectType: Type | undefined;
let typeClassType: Type | undefined;
let functionObj: Type | undefined;
let tupleClassType: Type | undefined;
let boolClassType: Type | undefined;
let strClassType: Type | undefined;
@ -688,6 +689,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
objectType = getBuiltInObject(node, 'object');
typeClassType = getBuiltInType(node, 'type');
functionObj = getBuiltInObject(node, 'function');
// Initialize and cache "Collection" to break a cyclical dependency
// that occurs when resolving tuple below.
@ -4125,7 +4127,6 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
const functionType = isFunction(baseType) ? baseType : baseType.overloads[0];
type = functionType.boundToType;
} else {
const functionObj = getBuiltInObject(node, 'function');
if (!functionObj) {
type = AnyType.create();
} else {
@ -19537,7 +19538,17 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
if (isClass(mroClass) && ClassType.isProtocolClass(mroClass)) {
for (const field of mroClass.details.fields) {
if (field[0] !== '__call__' && !field[1].isIgnoredForProtocolMatch()) {
return undefined;
let fieldIsPartOfFunction = false;
if (functionObj && isClass(functionObj)) {
if (functionObj.details.fields.has(field[0])) {
fieldIsPartOfFunction = true;
}
}
if (!fieldIsPartOfFunction) {
return undefined;
}
}
}
}

View File

@ -1,7 +1,7 @@
# This sample tests the case where a callback protocol defines additional
# attributes.
from typing import Callable, Literal, Protocol, TypeVar, cast
from typing import Any, Callable, Literal, Protocol, TypeVar, cast
from typing_extensions import ParamSpec
@ -9,7 +9,7 @@ P = ParamSpec("P")
R = TypeVar("R", covariant=True)
class SomeFunc(Protocol[P, R]):
class SomeFunc1(Protocol[P, R]):
__name__: str
other_attribute: int
@ -18,8 +18,8 @@ class SomeFunc(Protocol[P, R]):
...
def other_func(f: Callable[P, R]) -> SomeFunc[P, R]:
converted = cast(SomeFunc, f)
def other_func1(f: Callable[P, R]) -> SomeFunc1[P, R]:
converted = cast(SomeFunc1, f)
print(converted.__name__)
@ -34,16 +34,33 @@ def other_func(f: Callable[P, R]) -> SomeFunc[P, R]:
return converted
@other_func
def some_func(x: int) -> str:
@other_func1
def some_func1(x: int) -> str:
...
t1: Literal["SomeFunc[(x: int), str]"] = reveal_type(some_func)
t1: Literal["SomeFunc1[(x: int), str]"] = reveal_type(some_func1)
some_func.other_attribute
some_func1.other_attribute
# This should generate an error
some_func.other_attribute2
some_func1.other_attribute2
some_func(x=3)
some_func1(x=3)
class SomeFunc2(Protocol):
__name__: str
__module__: str
__qualname__: str
__annotations__: dict[str, Any]
def __call__(self) -> None:
...
def some_func2() -> None:
...
v: SomeFunc2 = some_func2