Modified the overload matching algorithm to match the behavior of mypy when the overload match is ambiguous because an argument evaluates to Any or Unknown. In this case, the call expression evaluates to Unknown. Previously, pyright used the first of the matching overloads in this case. This addresses https://github.com/microsoft/pyright/issues/4347.

This commit is contained in:
Eric Traut 2022-12-16 23:18:37 -08:00
parent 7770b1b6d9
commit 1dc84b93b1
7 changed files with 94 additions and 4 deletions

View File

@ -400,7 +400,7 @@ reveal_type(Child.method2()) # Type[Child]
### Overloads
Some functions or methods can return one of several different types. In cases where the return type depends on the types of the input parameters, it is useful to specify this using a series of `@overload` signatures. When Pyright evaluates a call expression, it determines which overload signature best matches the supplied arguments.
Some functions or methods can return one of several different types. In cases where the return type depends on the types of the input arguments, it is useful to specify this using a series of `@overload` signatures. When Pyright evaluates a call expression, it determines which overload signature best matches the supplied arguments.
[PEP 484](https://www.python.org/dev/peps/pep-0484/#function-method-overloading) introduced the `@overload` decorator and described how it can be used, but the PEP did not specify precisely how a type checker should choose the “best” overload. Pyright uses the following rules.
@ -410,7 +410,7 @@ Some functions or methods can return one of several different types. In cases wh
3. If only one overload remains, it is the “winner”.
4. If more than one overload remains, the “winner” is chosen based on the order in which the overloads are declared. In general, the first remaining overload is the “winner”. One exception to this rule is when a `*args` (unpacked) argument matches a `*args` parameter in one of the overload signatures. This situation overrides the normal order-based rule.
4. If more than one overload remains, the “winner” is chosen based on the order in which the overloads are declared. In general, the first remaining overload is the “winner”. One exception to this rule is when a `*args` (unpacked) argument matches a `*args` parameter in one of the overload signatures. This situation overrides the normal order-based rule. Another exception is when two or more overloads match because an argument evaluates to `Any` or `Unknown`. In this situation, the matching overload is ambiguous, so the call expression evaluates to `Unknown`.
5. If no overloads remain, Pyright considers whether any of the arguments are union types. If so, these union types are expanded into their constituent subtypes, and the entire process of overload matching is repeated with the expanded argument types. If two or more overloads match, the union of their respective return types form the final return type for the call expression.

View File

@ -7680,6 +7680,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
let matchedOverload: FunctionType | undefined;
const argTypeOverride = expandedArgTypes[expandedTypesIndex];
const hasArgTypeOverride = argTypeOverride.some((a) => a !== undefined);
const possibleMatchResults: Type[] = [];
let isDefinitiveMatchFound = false;
for (let overloadIndex = 0; overloadIndex < argParamMatches.length; overloadIndex++) {
const overload = argParamMatches[overloadIndex].overload;
@ -7732,8 +7734,27 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
matchResults,
typeVarContext: effectiveTypeVarContext,
});
returnTypes.push(callResult.returnType);
break;
if (callResult.isArgumentAnyOrUnknown) {
possibleMatchResults.push(callResult.returnType);
} else {
returnTypes.push(callResult.returnType);
isDefinitiveMatchFound = true;
break;
}
}
}
// If we didn't find a definitive match that doesn't depend on
// an Any or Unknown argument, fall back on the possible match.
// If there were multiple possible matches, evaluate the type as
// Unknown, but include the "possible types" to allow for completion
// suggestions.
if (!isDefinitiveMatchFound) {
if (possibleMatchResults.length > 1) {
returnTypes.push(UnknownType.createPossibleType(combineTypes(possibleMatchResults)));
} else {
returnTypes.push(possibleMatchResults[0]);
}
}
@ -7769,6 +7790,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
return {
argumentErrors: false,
isArgumentAnyOrUnknown: finalCallResult.isArgumentAnyOrUnknown,
returnType: combineTypes(returnTypes),
isTypeIncomplete,
specializedInitSelfType: finalCallResult.specializedInitSelfType,
@ -10160,6 +10182,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
let isTypeIncomplete = matchResults.isTypeIncomplete;
let argumentErrors = false;
let specializedInitSelfType: Type | undefined;
let isArgumentAnyOrUnknown = false;
const typeCondition = getTypeCondition(type);
if (type.boundTypeVarScopeId) {
@ -10291,6 +10314,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
condition = TypeCondition.combine(condition, argResult.condition) ?? [];
}
if (isAnyOrUnknown(argResult.argType)) {
isArgumentAnyOrUnknown = true;
}
if (type.details.paramSpec) {
if (argParam.argument.argumentCategory === ArgumentCategory.UnpackedList) {
if (isParamSpec(argResult.argType) && argResult.argType.paramSpecAccess === 'args') {
@ -10432,6 +10459,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
return {
argumentErrors,
isArgumentAnyOrUnknown,
returnType: specializedReturnType,
isTypeIncomplete,
activeParam: matchResults.activeParam,

View File

@ -307,6 +307,9 @@ export interface CallResult {
// Were any errors discovered when evaluating argument types?
argumentErrors: boolean;
// Did one or more arguments evaluated to Any or Unknown?
isArgumentAnyOrUnknown?: boolean;
// The parameter associated with the "active" argument (used
// for signature help provider)
activeParam?: FunctionParameter | undefined;

View File

@ -265,6 +265,12 @@ export namespace UnboundType {
export interface UnknownType extends TypeBase {
category: TypeCategory.Unknown;
isIncomplete: boolean;
// A "possible type" is a form of a "weak union" where the actual
// type is unknown, but it could be one of the subtypes in the union.
// This is used for overload matching in cases where more than one
// overload matches due to an argument that evaluates to Any or Unknown.
possibleType?: Type;
}
export namespace UnknownType {
@ -282,6 +288,17 @@ export namespace UnknownType {
export function create(isIncomplete = false) {
return isIncomplete ? _incompleteInstance : _instance;
}
export function createPossibleType(possibleType: Type) {
const unknownWithPossibleType: UnknownType = {
category: TypeCategory.Unknown,
flags: TypeFlags.Instantiable | TypeFlags.Instance,
isIncomplete: false,
possibleType,
};
return unknownWithPossibleType;
}
}
export interface ModuleType extends TypeBase {

View File

@ -1446,6 +1446,12 @@ export class CompletionProvider {
if (leftType) {
leftType = this._evaluator.makeTopLevelTypeVarsConcrete(leftType);
// If this is an unknown type with a "possible type" associated with
// it, use the possible type.
if (isUnknown(leftType) && leftType.possibleType) {
leftType = this._evaluator.makeTopLevelTypeVarsConcrete(leftType.possibleType);
}
doForEachSubtype(leftType, (subtype) => {
subtype = this._evaluator.makeTopLevelTypeVarsConcrete(subtype);

View File

@ -0,0 +1,31 @@
# This sample tests overload matching in cases where one or more
# matches are found due to an Any or Unknown argument.
from typing import Any, overload
@overload
def func1(x: int, y: float) -> float:
...
@overload
def func1(x: str, y: float) -> str:
...
def func1(x: str | int, y: float) -> float | str:
...
def func2(a: Any):
v1 = func1(1, 3.4)
reveal_type(v1, expected_text="float")
v2 = func1("", 3.4)
reveal_type(v2, expected_text="str")
v3 = func1(a, 3.4)
reveal_type(v3, expected_text="Unknown")
v4 = func1("", a)
reveal_type(v4, expected_text="str")

View File

@ -304,6 +304,11 @@ test('Overload11', () => {
TestUtils.validateResults(analysisResults, 1);
});
test('Overload12', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['overload12.py']);
TestUtils.validateResults(analysisResults, 0);
});
test('Final1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['final1.py']);
TestUtils.validateResults(analysisResults, 1);