Improved heuristics for when constraint solver should prefer literals over non-literals when solving a TypeVar in the case of bidirectional inference when assigning to a callable type.

This commit is contained in:
Eric Traut 2022-01-03 09:40:57 -07:00
parent 1098b379d7
commit 781b116c5a
3 changed files with 35 additions and 4 deletions

View File

@ -9069,12 +9069,17 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
effectiveExpectedType = applySolvedTypeVars(genericReturnType, tempTypeVarMap);
}
let effectiveFlags = CanAssignFlags.AllowTypeVarNarrowing;
if (containsLiteralType(effectiveExpectedType, /* includeTypeArgs */ true)) {
effectiveFlags |= CanAssignFlags.RetainLiteralsForTypeVar;
}
canAssignType(
effectiveReturnType,
effectiveExpectedType,
/* diag */ undefined,
typeVarMap,
CanAssignFlags.AllowTypeVarNarrowing | CanAssignFlags.RetainLiteralsForTypeVar
effectiveFlags
);
}
}

View File

@ -639,6 +639,24 @@ export function containsLiteralType(type: Type, includeTypeArgs = false, recursi
return type.subtypes.some((subtype) => containsLiteralType(subtype, includeTypeArgs, recursionCount + 1));
}
if (isOverloadedFunction(type)) {
return type.overloads.some((overload) => containsLiteralType(overload, includeTypeArgs, recursionCount + 1));
}
if (isFunction(type)) {
const returnType = FunctionType.getSpecializedReturnType(type);
if (returnType && containsLiteralType(returnType, includeTypeArgs, recursionCount + 1)) {
return true;
}
for (let i = 0; i < type.details.parameters.length; i++) {
const paramType = FunctionType.getEffectiveParameterType(type, i);
if (containsLiteralType(paramType, includeTypeArgs, recursionCount + 1)) {
return true;
}
}
}
return false;
}

View File

@ -4,14 +4,22 @@
# We need to validate that the type inference for lists
# is not over-narrowing when matching these literals.
from typing import List, Tuple, TypeVar
from typing import Callable, List, Tuple, TypeVar
T = TypeVar("T")
_T = TypeVar("_T")
def extend_if(xs: List[T], ys: List[Tuple[T, bool]]) -> List[T]:
def extend_if(xs: List[_T], ys: List[Tuple[_T, bool]]) -> List[_T]:
raise NotImplementedError()
extend_if(["foo"], [("bar", True), ("baz", True)])
def Return(value: _T) -> Callable[[_T], None]:
...
def func1() -> Callable[[bool], None]:
return Return(True)