From 1dc84b93b142da409ee506aea057a8a7d7bc7506 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Fri, 16 Dec 2022 23:18:37 -0800 Subject: [PATCH] Modified the overload matching algorithm to match the behavior of mypy when the overload match is ambiguous because an argument evaluates to `Any` or `Unknown`. In this case, the call expression evaluates to `Unknown`. Previously, pyright used the first of the matching overloads in this case. This addresses https://github.com/microsoft/pyright/issues/4347. --- docs/type-concepts.md | 4 +-- .../src/analyzer/typeEvaluator.ts | 32 +++++++++++++++++-- .../src/analyzer/typeEvaluatorTypes.ts | 3 ++ .../pyright-internal/src/analyzer/types.ts | 17 ++++++++++ .../src/languageService/completionProvider.ts | 6 ++++ .../src/tests/samples/overload12.py | 31 ++++++++++++++++++ .../src/tests/typeEvaluator4.test.ts | 5 +++ 7 files changed, 94 insertions(+), 4 deletions(-) create mode 100644 packages/pyright-internal/src/tests/samples/overload12.py diff --git a/docs/type-concepts.md b/docs/type-concepts.md index 6f78b117c..263817376 100644 --- a/docs/type-concepts.md +++ b/docs/type-concepts.md @@ -400,7 +400,7 @@ reveal_type(Child.method2()) # Type[Child] ### Overloads -Some functions or methods can return one of several different types. In cases where the return type depends on the types of the input parameters, it is useful to specify this using a series of `@overload` signatures. When Pyright evaluates a call expression, it determines which overload signature best matches the supplied arguments. +Some functions or methods can return one of several different types. In cases where the return type depends on the types of the input arguments, it is useful to specify this using a series of `@overload` signatures. When Pyright evaluates a call expression, it determines which overload signature best matches the supplied arguments. [PEP 484](https://www.python.org/dev/peps/pep-0484/#function-method-overloading) introduced the `@overload` decorator and described how it can be used, but the PEP did not specify precisely how a type checker should choose the “best” overload. Pyright uses the following rules. @@ -410,7 +410,7 @@ Some functions or methods can return one of several different types. In cases wh 3. If only one overload remains, it is the “winner”. -4. If more than one overload remains, the “winner” is chosen based on the order in which the overloads are declared. In general, the first remaining overload is the “winner”. One exception to this rule is when a `*args` (unpacked) argument matches a `*args` parameter in one of the overload signatures. This situation overrides the normal order-based rule. +4. If more than one overload remains, the “winner” is chosen based on the order in which the overloads are declared. In general, the first remaining overload is the “winner”. One exception to this rule is when a `*args` (unpacked) argument matches a `*args` parameter in one of the overload signatures. This situation overrides the normal order-based rule. Another exception is when two or more overloads match because an argument evaluates to `Any` or `Unknown`. In this situation, the matching overload is ambiguous, so the call expression evaluates to `Unknown`. 5. If no overloads remain, Pyright considers whether any of the arguments are union types. If so, these union types are expanded into their constituent subtypes, and the entire process of overload matching is repeated with the expanded argument types. If two or more overloads match, the union of their respective return types form the final return type for the call expression. diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 86c012b1d..979211ce8 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -7680,6 +7680,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions let matchedOverload: FunctionType | undefined; const argTypeOverride = expandedArgTypes[expandedTypesIndex]; const hasArgTypeOverride = argTypeOverride.some((a) => a !== undefined); + const possibleMatchResults: Type[] = []; + let isDefinitiveMatchFound = false; for (let overloadIndex = 0; overloadIndex < argParamMatches.length; overloadIndex++) { const overload = argParamMatches[overloadIndex].overload; @@ -7732,8 +7734,27 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions matchResults, typeVarContext: effectiveTypeVarContext, }); - returnTypes.push(callResult.returnType); - break; + + if (callResult.isArgumentAnyOrUnknown) { + possibleMatchResults.push(callResult.returnType); + } else { + returnTypes.push(callResult.returnType); + isDefinitiveMatchFound = true; + break; + } + } + } + + // If we didn't find a definitive match that doesn't depend on + // an Any or Unknown argument, fall back on the possible match. + // If there were multiple possible matches, evaluate the type as + // Unknown, but include the "possible types" to allow for completion + // suggestions. + if (!isDefinitiveMatchFound) { + if (possibleMatchResults.length > 1) { + returnTypes.push(UnknownType.createPossibleType(combineTypes(possibleMatchResults))); + } else { + returnTypes.push(possibleMatchResults[0]); } } @@ -7769,6 +7790,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions return { argumentErrors: false, + isArgumentAnyOrUnknown: finalCallResult.isArgumentAnyOrUnknown, returnType: combineTypes(returnTypes), isTypeIncomplete, specializedInitSelfType: finalCallResult.specializedInitSelfType, @@ -10160,6 +10182,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions let isTypeIncomplete = matchResults.isTypeIncomplete; let argumentErrors = false; let specializedInitSelfType: Type | undefined; + let isArgumentAnyOrUnknown = false; const typeCondition = getTypeCondition(type); if (type.boundTypeVarScopeId) { @@ -10291,6 +10314,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions condition = TypeCondition.combine(condition, argResult.condition) ?? []; } + if (isAnyOrUnknown(argResult.argType)) { + isArgumentAnyOrUnknown = true; + } + if (type.details.paramSpec) { if (argParam.argument.argumentCategory === ArgumentCategory.UnpackedList) { if (isParamSpec(argResult.argType) && argResult.argType.paramSpecAccess === 'args') { @@ -10432,6 +10459,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions return { argumentErrors, + isArgumentAnyOrUnknown, returnType: specializedReturnType, isTypeIncomplete, activeParam: matchResults.activeParam, diff --git a/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts b/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts index f68d32c5a..8a5a1b5eb 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts @@ -307,6 +307,9 @@ export interface CallResult { // Were any errors discovered when evaluating argument types? argumentErrors: boolean; + // Did one or more arguments evaluated to Any or Unknown? + isArgumentAnyOrUnknown?: boolean; + // The parameter associated with the "active" argument (used // for signature help provider) activeParam?: FunctionParameter | undefined; diff --git a/packages/pyright-internal/src/analyzer/types.ts b/packages/pyright-internal/src/analyzer/types.ts index 35c781f80..3d2ef4b11 100644 --- a/packages/pyright-internal/src/analyzer/types.ts +++ b/packages/pyright-internal/src/analyzer/types.ts @@ -265,6 +265,12 @@ export namespace UnboundType { export interface UnknownType extends TypeBase { category: TypeCategory.Unknown; isIncomplete: boolean; + + // A "possible type" is a form of a "weak union" where the actual + // type is unknown, but it could be one of the subtypes in the union. + // This is used for overload matching in cases where more than one + // overload matches due to an argument that evaluates to Any or Unknown. + possibleType?: Type; } export namespace UnknownType { @@ -282,6 +288,17 @@ export namespace UnknownType { export function create(isIncomplete = false) { return isIncomplete ? _incompleteInstance : _instance; } + + export function createPossibleType(possibleType: Type) { + const unknownWithPossibleType: UnknownType = { + category: TypeCategory.Unknown, + flags: TypeFlags.Instantiable | TypeFlags.Instance, + isIncomplete: false, + possibleType, + }; + + return unknownWithPossibleType; + } } export interface ModuleType extends TypeBase { diff --git a/packages/pyright-internal/src/languageService/completionProvider.ts b/packages/pyright-internal/src/languageService/completionProvider.ts index ab46599fc..0d98d6d07 100644 --- a/packages/pyright-internal/src/languageService/completionProvider.ts +++ b/packages/pyright-internal/src/languageService/completionProvider.ts @@ -1446,6 +1446,12 @@ export class CompletionProvider { if (leftType) { leftType = this._evaluator.makeTopLevelTypeVarsConcrete(leftType); + // If this is an unknown type with a "possible type" associated with + // it, use the possible type. + if (isUnknown(leftType) && leftType.possibleType) { + leftType = this._evaluator.makeTopLevelTypeVarsConcrete(leftType.possibleType); + } + doForEachSubtype(leftType, (subtype) => { subtype = this._evaluator.makeTopLevelTypeVarsConcrete(subtype); diff --git a/packages/pyright-internal/src/tests/samples/overload12.py b/packages/pyright-internal/src/tests/samples/overload12.py new file mode 100644 index 000000000..3ad5a77f0 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/overload12.py @@ -0,0 +1,31 @@ +# This sample tests overload matching in cases where one or more +# matches are found due to an Any or Unknown argument. + +from typing import Any, overload + + +@overload +def func1(x: int, y: float) -> float: + ... + +@overload +def func1(x: str, y: float) -> str: + ... + +def func1(x: str | int, y: float) -> float | str: + ... + + +def func2(a: Any): + v1 = func1(1, 3.4) + reveal_type(v1, expected_text="float") + + v2 = func1("", 3.4) + reveal_type(v2, expected_text="str") + + v3 = func1(a, 3.4) + reveal_type(v3, expected_text="Unknown") + + v4 = func1("", a) + reveal_type(v4, expected_text="str") + diff --git a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts index 2d2b65f22..9bd507d6a 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts @@ -304,6 +304,11 @@ test('Overload11', () => { TestUtils.validateResults(analysisResults, 1); }); +test('Overload12', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['overload12.py']); + TestUtils.validateResults(analysisResults, 0); +}); + test('Final1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['final1.py']); TestUtils.validateResults(analysisResults, 1);