Modified behavior of overload matching when unpacked argument is present and the unpacked iterator is a type that doesn't provide any length information. In this case, overload matching will prefer an overload that includes a *args parameter rather than individual positional parameters.

This commit is contained in:
Eric Traut 2021-12-13 13:43:48 -08:00
parent 211ebe44dd
commit c9647b9a03
3 changed files with 148 additions and 46 deletions

View File

@ -298,11 +298,18 @@ interface EffectiveTypeCacheEntry {
}
interface MatchArgsToParamsResult {
overload: FunctionType;
overloadIndex: number;
argumentErrors: boolean;
argParams: ValidateArgTypeParams[];
activeParam?: FunctionParameter | undefined;
paramSpecTarget?: TypeVarType | undefined;
paramSpecArgList?: FunctionArgument[] | undefined;
// A higher relevance means that it should be considered
// first, before lower relevance overloads.
relevance: number;
}
interface ArgResult {
@ -6536,7 +6543,6 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
function validateOverloadsWithExpandedTypes(
errorNode: ExpressionNode,
expandedArgTypes: (Type | undefined)[][],
overloads: FunctionType[],
argParamMatches: MatchArgsToParamsResult[],
typeVarMap: TypeVarMap | undefined,
skipUnknownArgCheck: boolean,
@ -6555,8 +6561,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
const argTypeOverride = expandedArgTypes[expandedTypesIndex];
const hasArgTypeOverride = argTypeOverride.some((a) => a !== undefined);
for (let overloadIndex = 0; overloadIndex < overloads.length; overloadIndex++) {
const overload = overloads[overloadIndex];
for (let overloadIndex = 0; overloadIndex < argParamMatches.length; overloadIndex++) {
const overload = argParamMatches[overloadIndex].overload;
let matchResults = argParamMatches[overloadIndex];
if (hasArgTypeOverride) {
@ -6583,7 +6589,6 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
return validateFunctionArgumentTypes(
errorNode,
matchResults,
overload,
effectiveTypeVarMap,
/* skipUnknownArgCheck */ true,
expectedType
@ -6620,7 +6625,6 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
return validateFunctionArgumentTypes(
errorNode,
matchResults,
overload,
typeVarMap,
/* skipUnknownArgCheck */ true,
expectedType
@ -6631,12 +6635,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
// And run through the first expanded argument list one more time to
// populate the type cache.
const firstExpansionOverload = matchedOverloads[0].overload;
matchedOverloads[0].typeVarMap.unlock();
const finalCallResult = validateFunctionArgumentTypes(
errorNode,
matchedOverloads[0].matchResults,
firstExpansionOverload,
matchedOverloads[0].typeVarMap,
skipUnknownArgCheck,
expectedType
@ -6659,33 +6661,63 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
type: OverloadedFunctionType,
argList: FunctionArgument[]
): FunctionType | undefined {
let firstMatch: FunctionType | undefined;
let overloadIndex = 0;
let matches: MatchArgsToParamsResult[] = [];
// Create a list of potential overload matches based on arguments.
type.overloads.forEach((overload) => {
if (!firstMatch) {
useSpeculativeMode(errorNode, () => {
if (FunctionType.isOverloaded(overload)) {
const matchResults = matchFunctionArgumentsToParameters(errorNode, argList, overload);
if (!matchResults.argumentErrors) {
const callResult = validateFunctionArgumentTypes(
errorNode,
matchResults,
overload,
new TypeVarMap(getTypeVarScopeId(overload)),
/* skipUnknownArgCheck */ true,
/* expectedType */ undefined
);
useSpeculativeMode(errorNode, () => {
if (FunctionType.isOverloaded(overload)) {
const matchResults = matchFunctionArgumentsToParameters(
errorNode,
argList,
overload,
overloadIndex
);
if (callResult && !callResult.argumentErrors) {
firstMatch = overload;
}
}
if (!matchResults.argumentErrors) {
matches.push(matchResults);
}
overloadIndex++;
}
});
});
matches = sortOverloadsByBestMatch(matches);
let winningOverloadIndex: number | undefined;
matches.forEach((match, matchIndex) => {
if (winningOverloadIndex === undefined) {
useSpeculativeMode(errorNode, () => {
const callResult = validateFunctionArgumentTypes(
errorNode,
match,
new TypeVarMap(getTypeVarScopeId(match.overload)),
/* skipUnknownArgCheck */ true,
/* expectedType */ undefined
);
if (callResult && !callResult.argumentErrors) {
winningOverloadIndex = matchIndex;
}
});
}
});
return firstMatch;
return winningOverloadIndex === undefined ? undefined : matches[winningOverloadIndex].overload;
}
// Sorts the list of overloads based first on "relevance" and second on order.
function sortOverloadsByBestMatch(matches: MatchArgsToParamsResult[]) {
return matches.sort((a, b) => {
if (a.relevance !== b.relevance) {
return b.relevance - a.relevance;
}
return a.overloadIndex - b.overloadIndex;
});
}
function validateOverloadedFunctionArguments(
@ -6696,8 +6728,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
skipUnknownArgCheck: boolean,
expectedType: Type | undefined
): CallResult {
const filteredOverloads: FunctionType[] = [];
const filteredMatchResults: MatchArgsToParamsResult[] = [];
let filteredMatchResults: MatchArgsToParamsResult[] = [];
let contextFreeArgTypes: Type[] = [];
// Start by evaluating the types of the arguments without any expected
@ -6706,16 +6737,23 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
// speculatively because we don't want to record any types in the type
// cache or record any diagnostics at this stage.
useSpeculativeMode(errorNode, () => {
let overloadIndex = 0;
type.overloads.forEach((overload) => {
// Consider only the functions that have the @overload decorator,
// not the final function that omits the overload. This is the
// intended behavior according to PEP 484.
if (FunctionType.isOverloaded(overload)) {
const matchResults = matchFunctionArgumentsToParameters(errorNode, argList, overload);
const matchResults = matchFunctionArgumentsToParameters(
errorNode,
argList,
overload,
overloadIndex
);
if (!matchResults.argumentErrors) {
filteredOverloads.push(overload);
filteredMatchResults.push(matchResults);
}
overloadIndex++;
}
});
@ -6731,6 +6769,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
);
});
filteredMatchResults = sortOverloadsByBestMatch(filteredMatchResults);
// If there are no possible arg/param matches among the overloads,
// emit an error that includes the argument types.
if (filteredMatchResults.length === 0) {
@ -6758,17 +6798,19 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
// Create a helper lambda that evaluates the overload that matches
// the arg/param lists.
const evaluateUsingLastMatchingOverload = (skipUnknownArgCheck: boolean) => {
const lastOverload = filteredOverloads[filteredOverloads.length - 1];
const lastMatch = filteredMatchResults[filteredOverloads.length - 1];
// Find the match with the largest overload index (i.e. the last overload
// that was in the overload list).
const lastMatch = filteredMatchResults.reduce((previous, current) => {
return current.overloadIndex > previous.overloadIndex ? current : previous;
});
const effectiveTypeVarMap = typeVarMap ?? new TypeVarMap();
effectiveTypeVarMap.addSolveForScope(getTypeVarScopeId(lastOverload));
effectiveTypeVarMap.addSolveForScope(getTypeVarScopeId(lastMatch.overload));
effectiveTypeVarMap.unlock();
return validateFunctionArgumentTypes(
errorNode,
lastMatch,
lastOverload,
effectiveTypeVarMap,
skipUnknownArgCheck,
expectedType
@ -6789,7 +6831,6 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
const callResult = validateOverloadsWithExpandedTypes(
errorNode,
expandedArgTypes,
filteredOverloads,
filteredMatchResults,
typeVarMap,
skipUnknownArgCheck,
@ -7836,10 +7877,12 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
function matchFunctionArgumentsToParameters(
errorNode: ExpressionNode,
argList: FunctionArgument[],
type: FunctionType
type: FunctionType,
overloadIndex: number
): MatchArgsToParamsResult {
let argIndex = 0;
const typeParams = type.details.parameters;
let matchedUnpackedListOfUnknownLength = false;
// The last parameter might be a var arg dictionary. If so, strip it off.
const varArgDictParam = typeParams.find((param) => param.category === ParameterCategory.VarArgDictionary);
@ -8073,6 +8116,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
listElementType =
getTypeFromIterator(argType, /* isAsync */ false, argList[argIndex].valueExpression!) ||
UnknownType.create();
if (typeParams[paramIndex].category !== ParameterCategory.VarArgList) {
matchedUnpackedListOfUnknownLength = true;
}
}
const funcArg: FunctionArgument | undefined = listElementType
@ -8598,12 +8645,23 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}
}
let relevance = 0;
if (matchedUnpackedListOfUnknownLength) {
// Lower the relevance if we made assumptions about the length
// of an unpacked argument. This will favor overloads that
// associate this case with a *args parameter.
relevance--;
}
return {
overload: type,
overloadIndex,
argumentErrors: reportedArgError,
argParams: validateArgTypeParams,
paramSpecTarget,
paramSpecArgList,
activeParam,
relevance,
};
}
@ -8613,7 +8671,6 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
function validateFunctionArgumentTypes(
errorNode: ExpressionNode,
matchResults: MatchArgsToParamsResult,
type: FunctionType,
typeVarMap: TypeVarMap,
skipUnknownArgCheck = false,
expectedType?: Type
@ -8621,6 +8678,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
let isTypeIncomplete = false;
let argumentErrors = false;
let specializedInitSelfType: Type | undefined;
const type = matchResults.overload;
const typeCondition = getTypeCondition(type);
@ -8896,7 +8954,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
skipUnknownArgCheck = false,
expectedType?: Type
): CallResult {
const matchResults = matchFunctionArgumentsToParameters(errorNode, argList, type);
const matchResults = matchFunctionArgumentsToParameters(errorNode, argList, type, 0);
if (matchResults.argumentErrors) {
// Evaluate types of all args. This will ensure that referenced symbols are
@ -8915,14 +8973,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
};
}
return validateFunctionArgumentTypes(
errorNode,
matchResults,
type,
typeVarMap,
skipUnknownArgCheck,
expectedType
);
return validateFunctionArgumentTypes(errorNode, matchResults, typeVarMap, skipUnknownArgCheck, expectedType);
}
// Determines whether the specified argument list satisfies the function

View File

@ -0,0 +1,46 @@
# This sample tests an overload that provides a signature for
# a *args parameter.
from typing import Iterable, Literal, Tuple, TypeVar, overload
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
@overload
def func1(__iter1: Iterable[_T1]) -> Tuple[_T1]:
...
@overload
def func1(__iter1: Iterable[_T1], __iter2: Iterable[_T2]) -> Tuple[_T1, _T2]:
...
@overload
def func1(*iterables: Iterable[_T1]) -> Tuple[_T1, ...]:
...
def func1(*iterables: Iterable[_T1]) -> Tuple[_T1, ...]:
...
def func2(x: Iterable[int]):
v1 = func1(x)
t1: Literal["Tuple[int]"] = reveal_type(v1)
v2 = func1(x, x)
t2: Literal["Tuple[int, int]"] = reveal_type(v2)
y = [x, x, x, x]
v3 = func1(*y)
t3: Literal["Tuple[int, ...]"] = reveal_type(v3)
z = (x, x)
v4 = func1(*z)
t4: Literal["Tuple[int, int]"] = reveal_type(v4)

View File

@ -271,6 +271,11 @@ test('Overload9', () => {
TestUtils.validateResults(analysisResults, 1);
});
test('Overload10', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['overload10.py']);
TestUtils.validateResults(analysisResults, 0);
});
test('Final1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['final1.py']);
TestUtils.validateResults(analysisResults, 1);