Changed tuple expression inference behavior to not preserve literal entry types if the tuple expression is embedded within another tuple, set, list, or dictionary expression. This addresses #7159. (#7970)

This commit is contained in:
Eric Traut 2024-05-22 16:10:05 -07:00 committed by GitHub
parent 0618acc535
commit 50d4f44735
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 102 additions and 37 deletions

View File

@ -261,11 +261,11 @@ var1 = [4]
When inferring the type of a tuple expression (in the absence of bidirectional inference hints), Pyright assumes that the tuple has a fixed length, and each tuple element is typed as specifically as possible.
```python
# The inferred type is tuple[Literal[1], Literal["a"], Literal[True]]
# The inferred type is tuple[Literal[1], Literal["a"], Literal[True]].
var1 = (1, "a", True)
def func1(a: int):
# The inferred type is tuple[int, int]
# The inferred type is tuple[int, int].
var2 = (a, a)
# If you want the type to be tuple[int, ...]
@ -274,6 +274,13 @@ def func1(a: int):
var3: tuple[int, ...] = (a, a)
```
Because tuples are typed as specifically as possible, literal types are normally retained. However, as an exception to this inference rule, if the tuple expression is nested within another tuple, set, list or dictionary expression, literal types are not retained. This is done to avoid the inference of complex types (e.g. unions with many subtypes) when evaluating tuple statements with many entries.
```python
# The inferred type is list[tuple[int, str, bool]].
var4 = [(1, "a", True), (2, "b", False), (3, "c", False)]
```
#### List Expressions
When inferring the type of a list expression (in the absence of bidirectional inference hints), Pyright uses the following heuristics:

View File

@ -1175,7 +1175,7 @@ export function createTypeEvaluator(
}
case ParseNodeType.ListComprehension: {
typeResult = getTypeOfListComprehension(node, inferenceContext);
typeResult = getTypeOfListComprehension(node, flags, inferenceContext);
break;
}
@ -7826,6 +7826,12 @@ export function createTypeEvaluator(
return { type: makeTupleObject([]), isEmptyTupleShorthand: true };
}
flags &= ~(
EvaluatorFlags.ExpectingTypeAnnotation |
EvaluatorFlags.EvaluateStringLiteralAsType |
EvaluatorFlags.ExpectingInstantiableType
);
// If the expected type is a union, recursively call for each of the subtypes
// to find one that matches.
let effectiveExpectedType = inferenceContext?.expectedType;
@ -7845,6 +7851,7 @@ export function createTypeEvaluator(
const subtypeResult = useSpeculativeMode(node, () => {
return getTypeOfTupleWithContext(
node,
flags,
makeInferenceContext(subtype),
/* signatureTracker */ undefined
);
@ -7865,6 +7872,7 @@ export function createTypeEvaluator(
if (effectiveExpectedType) {
const result = getTypeOfTupleWithContext(
node,
flags,
makeInferenceContext(effectiveExpectedType),
signatureTracker
);
@ -7876,7 +7884,7 @@ export function createTypeEvaluator(
expectedTypeDiagAddendum = result?.expectedTypeDiagAddendum;
}
const typeResult = getTypeOfTupleInferred(node);
const typeResult = getTypeOfTupleInferred(node, flags);
// If there was an expected type of Any, replace the resulting type
// with Any rather than return a type with unknowns.
@ -7889,6 +7897,7 @@ export function createTypeEvaluator(
function getTypeOfTupleWithContext(
node: TupleNode,
flags: EvaluatorFlags,
inferenceContext: InferenceContext,
signatureTracker: UniqueSignatureTracker | undefined
): TypeResult | undefined {
@ -7947,7 +7956,7 @@ export function createTypeEvaluator(
const entryTypeResults = node.expressions.map((expr, index) =>
getTypeOfExpression(
expr,
/* flags */ undefined,
flags | EvaluatorFlags.StripLiteralTypeForTuple,
makeInferenceContext(
index < expectedTypes.length ? expectedTypes[index] : undefined,
inferenceContext.isTypeIncomplete
@ -7956,7 +7965,7 @@ export function createTypeEvaluator(
)
);
const isIncomplete = entryTypeResults.some((result) => result.isIncomplete);
const type = makeTupleObject(buildTupleTypesList(entryTypeResults));
const type = makeTupleObject(buildTupleTypesList(entryTypeResults, /* stripLiterals */ false));
// Copy any expected type diag addenda for precision error reporting.
let expectedTypeDiagAddendum: DiagnosticAddendum | undefined;
@ -7972,11 +7981,15 @@ export function createTypeEvaluator(
return { type, expectedTypeDiagAddendum, isIncomplete };
}
function getTypeOfTupleInferred(node: TupleNode): TypeResult {
const entryTypeResults = node.expressions.map((expr) => getTypeOfExpression(expr));
function getTypeOfTupleInferred(node: TupleNode, flags: EvaluatorFlags): TypeResult {
const entryTypeResults = node.expressions.map((expr) =>
getTypeOfExpression(expr, flags | EvaluatorFlags.StripLiteralTypeForTuple)
);
const isIncomplete = entryTypeResults.some((result) => result.isIncomplete);
const type = makeTupleObject(buildTupleTypesList(entryTypeResults));
const type = makeTupleObject(
buildTupleTypesList(entryTypeResults, (flags & EvaluatorFlags.StripLiteralTypeForTuple) !== 0)
);
if (isIncomplete) {
if (getContainerDepth(type) > maxInferredContainerDepth) {
@ -7987,7 +8000,7 @@ export function createTypeEvaluator(
return { type, isIncomplete };
}
function buildTupleTypesList(entryTypeResults: TypeResult[]): TupleTypeArgument[] {
function buildTupleTypesList(entryTypeResults: TypeResult[], stripLiterals: boolean): TupleTypeArgument[] {
const entryTypes: TupleTypeArgument[] = [];
for (const typeResult of entryTypeResults) {
@ -8017,7 +8030,8 @@ export function createTypeEvaluator(
} else if (isNever(typeResult.type) && typeResult.isIncomplete && !typeResult.unpackedType) {
entryTypes.push({ type: UnknownType.create(/* isIncomplete */ true), isUnbounded: false });
} else {
entryTypes.push({ type: typeResult.type, isUnbounded: !!typeResult.unpackedType });
const entryType = stripLiterals ? stripLiteralValue(typeResult.type) : typeResult.type;
entryTypes.push({ type: entryType, isUnbounded: !!typeResult.unpackedType });
}
}
@ -13318,7 +13332,7 @@ export function createTypeEvaluator(
}
const subtypeResult = useSpeculativeMode(node, () => {
return getTypeOfDictionaryWithContext(node, makeInferenceContext(subtype));
return getTypeOfDictionaryWithContext(node, flags, makeInferenceContext(subtype));
});
if (subtypeResult && assignType(subtype, subtypeResult.type)) {
@ -13341,6 +13355,7 @@ export function createTypeEvaluator(
expectedTypeDiagAddendum = new DiagnosticAddendum();
const result = getTypeOfDictionaryWithContext(
node,
flags,
makeInferenceContext(effectiveExpectedType),
expectedTypeDiagAddendum
);
@ -13349,12 +13364,13 @@ export function createTypeEvaluator(
}
}
const result = getTypeOfDictionaryInferred(node, /* hasExpectedType */ !!inferenceContext);
const result = getTypeOfDictionaryInferred(node, flags, /* hasExpectedType */ !!inferenceContext);
return { ...result, expectedTypeDiagAddendum };
}
function getTypeOfDictionaryWithContext(
node: DictionaryNode,
flags: EvaluatorFlags,
inferenceContext: InferenceContext,
expectedDiagAddendum?: DiagnosticAddendum
): TypeResult | undefined {
@ -13381,6 +13397,7 @@ export function createTypeEvaluator(
// Infer the key and value types if possible.
const keyValueTypeResult = getKeyAndValueTypesFromDictionary(
node,
flags,
keyTypes,
valueTypes,
/* forceStrictInference */ true,
@ -13472,6 +13489,7 @@ export function createTypeEvaluator(
// Infer the key and value types if possible.
const keyValueResult = getKeyAndValueTypesFromDictionary(
node,
flags,
keyTypes,
valueTypes,
/* forceStrictInference */ true,
@ -13510,7 +13528,11 @@ export function createTypeEvaluator(
// Attempts to infer the type of a dictionary statement. If hasExpectedType
// is true, strict inference is used for the subexpressions.
function getTypeOfDictionaryInferred(node: DictionaryNode, hasExpectedType: boolean): TypeResult {
function getTypeOfDictionaryInferred(
node: DictionaryNode,
flags: EvaluatorFlags,
hasExpectedType: boolean
): TypeResult {
const fallbackType = hasExpectedType ? AnyType.create() : UnknownType.create();
let keyType: Type = fallbackType;
let valueType: Type = fallbackType;
@ -13525,6 +13547,7 @@ export function createTypeEvaluator(
// Infer the key and value types if possible.
const keyValueResult = getKeyAndValueTypesFromDictionary(
node,
flags,
keyTypeResults,
valueTypeResults,
/* forceStrictInference */ hasExpectedType,
@ -13586,6 +13609,7 @@ export function createTypeEvaluator(
function getKeyAndValueTypesFromDictionary(
node: DictionaryNode,
flags: EvaluatorFlags,
keyTypes: TypeResultWithNode[],
valueTypes: TypeResultWithNode[],
forceStrictInference: boolean,
@ -13598,6 +13622,16 @@ export function createTypeEvaluator(
let isIncomplete = false;
let typeErrors = false;
// Mask out some of the flags that are not applicable for a dictionary key
// even if it appears within an inlined TypedDict annotation.
const keyFlags =
flags &
~(
EvaluatorFlags.ExpectingTypeAnnotation |
EvaluatorFlags.EvaluateStringLiteralAsType |
EvaluatorFlags.ExpectingInstantiableType
);
// Infer the key and value types if possible.
node.entries.forEach((entryNode, index) => {
let addUnknown = true;
@ -13605,7 +13639,7 @@ export function createTypeEvaluator(
if (entryNode.nodeType === ParseNodeType.DictionaryKeyEntry) {
const keyTypeResult = getTypeOfExpression(
entryNode.keyExpression,
/* flags */ undefined,
keyFlags | EvaluatorFlags.StripLiteralTypeForTuple,
makeInferenceContext(
expectedKeyType ?? (forceStrictInference ? NeverType.createNever() : undefined)
)
@ -13645,7 +13679,7 @@ export function createTypeEvaluator(
entryInferenceContext = makeInferenceContext(effectiveValueType);
valueTypeResult = getTypeOfExpression(
entryNode.valueExpression,
/* flags */ undefined,
flags | EvaluatorFlags.StripLiteralTypeForTuple,
entryInferenceContext
);
} else {
@ -13654,7 +13688,7 @@ export function createTypeEvaluator(
entryInferenceContext = makeInferenceContext(effectiveValueType);
valueTypeResult = getTypeOfExpression(
entryNode.valueExpression,
/* flags */ undefined,
flags | EvaluatorFlags.StripLiteralTypeForTuple,
entryInferenceContext
);
}
@ -13717,7 +13751,7 @@ export function createTypeEvaluator(
const entryInferenceContext = makeInferenceContext(expectedType);
let unexpandedTypeResult = getTypeOfExpression(
entryNode.expandExpression,
/* flags */ undefined,
flags | EvaluatorFlags.StripLiteralTypeForTuple,
entryInferenceContext
);
@ -13817,6 +13851,7 @@ export function createTypeEvaluator(
} else if (entryNode.nodeType === ParseNodeType.ListComprehension) {
const dictEntryTypeResult = getElementTypeFromListComprehension(
entryNode,
flags,
expectedValueType,
expectedKeyType
);
@ -13868,6 +13903,12 @@ export function createTypeEvaluator(
addDiagnostic(DiagnosticRule.reportInvalidTypeForm, LocMessage.listInAnnotation() + diag.getString(), node);
}
flags &= ~(
EvaluatorFlags.ExpectingTypeAnnotation |
EvaluatorFlags.EvaluateStringLiteralAsType |
EvaluatorFlags.ExpectingInstantiableType
);
// If the expected type is a union, recursively call for each of the subtypes
// to find one that matches.
let effectiveExpectedType = inferenceContext?.expectedType;
@ -13885,7 +13926,7 @@ export function createTypeEvaluator(
}
const subtypeResult = useSpeculativeMode(node, () => {
return getTypeOfListOrSetWithContext(node, makeInferenceContext(subtype));
return getTypeOfListOrSetWithContext(node, flags, makeInferenceContext(subtype));
});
if (subtypeResult && assignType(subtype, subtypeResult.type)) {
@ -13905,7 +13946,7 @@ export function createTypeEvaluator(
let expectedTypeDiagAddendum: DiagnosticAddendum | undefined;
if (effectiveExpectedType) {
const result = getTypeOfListOrSetWithContext(node, makeInferenceContext(effectiveExpectedType));
const result = getTypeOfListOrSetWithContext(node, flags, makeInferenceContext(effectiveExpectedType));
if (result && !result.typeErrors) {
return result;
}
@ -13913,7 +13954,11 @@ export function createTypeEvaluator(
expectedTypeDiagAddendum = result?.expectedTypeDiagAddendum;
}
const typeResult = getTypeOfListOrSetInferred(node, /* hasExpectedType */ inferenceContext !== undefined);
const typeResult = getTypeOfListOrSetInferred(
node,
flags,
/* hasExpectedType */ inferenceContext !== undefined
);
return { ...typeResult, expectedTypeDiagAddendum };
}
@ -13921,6 +13966,7 @@ export function createTypeEvaluator(
// Returns undefined if that type cannot be honored.
function getTypeOfListOrSetWithContext(
node: ListNode | SetNode,
flags: EvaluatorFlags,
inferenceContext: InferenceContext
): TypeResult | undefined {
const builtInClassName = node.nodeType === ParseNodeType.List ? 'list' : 'set';
@ -13945,11 +13991,11 @@ export function createTypeEvaluator(
let entryTypeResult: TypeResult;
if (entry.nodeType === ParseNodeType.ListComprehension) {
entryTypeResult = getElementTypeFromListComprehension(entry, expectedEntryType);
entryTypeResult = getElementTypeFromListComprehension(entry, flags, expectedEntryType);
} else {
entryTypeResult = getTypeOfExpression(
entry,
/* flags */ undefined,
flags | EvaluatorFlags.StripLiteralTypeForTuple,
makeInferenceContext(expectedEntryType)
);
}
@ -14044,7 +14090,11 @@ export function createTypeEvaluator(
}
// Attempts to infer the type of a list or set statement with no "expected type".
function getTypeOfListOrSetInferred(node: ListNode | SetNode, hasExpectedType: boolean): TypeResult {
function getTypeOfListOrSetInferred(
node: ListNode | SetNode,
flags: EvaluatorFlags,
hasExpectedType: boolean
): TypeResult {
const builtInClassName = node.nodeType === ParseNodeType.List ? 'list' : 'set';
const verifyHashable = node.nodeType === ParseNodeType.Set;
let isEmptyContainer = false;
@ -14056,9 +14106,9 @@ export function createTypeEvaluator(
let entryTypeResult: TypeResult;
if (entry.nodeType === ParseNodeType.ListComprehension && !entry.isGenerator) {
entryTypeResult = getElementTypeFromListComprehension(entry);
entryTypeResult = getElementTypeFromListComprehension(entry, flags);
} else {
entryTypeResult = getTypeOfExpression(entry);
entryTypeResult = getTypeOfExpression(entry, flags | EvaluatorFlags.StripLiteralTypeForTuple);
}
if (entryTypeResult.isIncomplete) {
@ -14479,7 +14529,11 @@ export function createTypeEvaluator(
return { type: functionType, isIncomplete, typeErrors };
}
function getTypeOfListComprehension(node: ListComprehensionNode, inferenceContext?: InferenceContext): TypeResult {
function getTypeOfListComprehension(
node: ListComprehensionNode,
flags: EvaluatorFlags,
inferenceContext?: InferenceContext
): TypeResult {
let isIncomplete = false;
let typeErrors = false;
@ -14501,7 +14555,7 @@ export function createTypeEvaluator(
const builtInIteratorType = getTypingType(node, isAsync ? 'AsyncGenerator' : 'Generator');
const expectedEntryType = getExpectedEntryTypeForIterable(node, builtInIteratorType, inferenceContext);
const elementTypeResult = getElementTypeFromListComprehension(node, expectedEntryType);
const elementTypeResult = getElementTypeFromListComprehension(node, flags, expectedEntryType);
if (elementTypeResult.isIncomplete) {
isIncomplete = true;
}
@ -14610,6 +14664,7 @@ export function createTypeEvaluator(
// as opposed to the entire list.
function getElementTypeFromListComprehension(
node: ListComprehensionNode,
flags: EvaluatorFlags,
expectedValueOrElementType?: Type,
expectedKeyType?: Type
): TypeResult {
@ -14628,7 +14683,7 @@ export function createTypeEvaluator(
// Create a tuple with the key/value types.
const keyTypeResult = getTypeOfExpression(
node.expression.keyExpression,
/* flags */ undefined,
flags | EvaluatorFlags.StripLiteralTypeForTuple,
makeInferenceContext(expectedKeyType)
);
if (keyTypeResult.isIncomplete) {
@ -14644,7 +14699,7 @@ export function createTypeEvaluator(
const valueTypeResult = getTypeOfExpression(
node.expression.valueExpression,
/* flags */ undefined,
flags | EvaluatorFlags.StripLiteralTypeForTuple,
makeInferenceContext(expectedValueOrElementType)
);
if (valueTypeResult.isIncomplete) {
@ -14666,13 +14721,13 @@ export function createTypeEvaluator(
// The parser should have reported an error in this case because it's not allowed.
getTypeOfExpression(
node.expression.expandExpression,
/* flags */ undefined,
flags | EvaluatorFlags.StripLiteralTypeForTuple,
makeInferenceContext(expectedValueOrElementType)
);
} else if (isExpressionNode(node)) {
const exprTypeResult = getTypeOfExpression(
node.expression as ExpressionNode,
/* flags */ undefined,
flags | EvaluatorFlags.StripLiteralTypeForTuple,
makeInferenceContext(expectedValueOrElementType)
);
if (exprTypeResult.isIncomplete) {

View File

@ -154,6 +154,10 @@ export const enum EvaluatorFlags {
// Allow use of the Concatenate special form.
AllowConcatenate = 1 << 27,
// Do not infer literal types within a tuple (used for tuples nested within
// other container classes).
StripLiteralTypeForTuple = 1 << 28,
// Defaults used for evaluating the LHS of a call expression.
CallBaseDefaults = DoNotSpecialize,

View File

@ -15,4 +15,4 @@ times = [
for meridian in ("am", "pm")
)
]
reveal_type(times, expected_text="list[tuple[int, int, Literal['am', 'pm']]]")
reveal_type(times, expected_text="list[tuple[int, int, str]]")

View File

@ -1,8 +1,7 @@
# This sample tests the inferred type of async and sync generators.
async def foo() -> int:
...
async def foo() -> int: ...
async def main() -> None:
@ -19,7 +18,7 @@ async def main() -> None:
reveal_type(v4, expected_text="AsyncGenerator[int, None]")
v5 = ((0, await foo()) for _ in [1, 2])
reveal_type(v5, expected_text="AsyncGenerator[tuple[Literal[0], int], None]")
reveal_type(v5, expected_text="AsyncGenerator[tuple[int, int], None]")
v6 = (x for x in [1, 2] if (x, await foo()))
reveal_type(v6, expected_text="AsyncGenerator[int, None]")

View File

@ -742,7 +742,7 @@ test('TypedDictInline1', () => {
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typedDictInline1.py'], configOptions);
TestUtils.validateResults(analysisResults, 8);
TestUtils.validateResults(analysisResults, 9);
});
test('ClassVar1', () => {