diff --git a/docs/type-inference.md b/docs/type-inference.md index 2f9833a4e..8d5b456f6 100644 --- a/docs/type-inference.md +++ b/docs/type-inference.md @@ -261,11 +261,11 @@ var1 = [4] When inferring the type of a tuple expression (in the absence of bidirectional inference hints), Pyright assumes that the tuple has a fixed length, and each tuple element is typed as specifically as possible. ```python -# The inferred type is tuple[Literal[1], Literal["a"], Literal[True]] +# The inferred type is tuple[Literal[1], Literal["a"], Literal[True]]. var1 = (1, "a", True) def func1(a: int): - # The inferred type is tuple[int, int] + # The inferred type is tuple[int, int]. var2 = (a, a) # If you want the type to be tuple[int, ...] @@ -274,6 +274,13 @@ def func1(a: int): var3: tuple[int, ...] = (a, a) ``` +Because tuples are typed as specifically as possible, literal types are normally retained. However, as an exception to this inference rule, if the tuple expression is nested within another tuple, set, list or dictionary expression, literal types are not retained. This is done to avoid the inference of complex types (e.g. unions with many subtypes) when evaluating tuple statements with many entries. + +```python +# The inferred type is list[tuple[int, str, bool]]. +var4 = [(1, "a", True), (2, "b", False), (3, "c", False)] +``` + #### List Expressions When inferring the type of a list expression (in the absence of bidirectional inference hints), Pyright uses the following heuristics: diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 956ee1ba2..701611e08 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -1175,7 +1175,7 @@ export function createTypeEvaluator( } case ParseNodeType.ListComprehension: { - typeResult = getTypeOfListComprehension(node, inferenceContext); + typeResult = getTypeOfListComprehension(node, flags, inferenceContext); break; } @@ -7826,6 +7826,12 @@ export function createTypeEvaluator( return { type: makeTupleObject([]), isEmptyTupleShorthand: true }; } + flags &= ~( + EvaluatorFlags.ExpectingTypeAnnotation | + EvaluatorFlags.EvaluateStringLiteralAsType | + EvaluatorFlags.ExpectingInstantiableType + ); + // If the expected type is a union, recursively call for each of the subtypes // to find one that matches. let effectiveExpectedType = inferenceContext?.expectedType; @@ -7845,6 +7851,7 @@ export function createTypeEvaluator( const subtypeResult = useSpeculativeMode(node, () => { return getTypeOfTupleWithContext( node, + flags, makeInferenceContext(subtype), /* signatureTracker */ undefined ); @@ -7865,6 +7872,7 @@ export function createTypeEvaluator( if (effectiveExpectedType) { const result = getTypeOfTupleWithContext( node, + flags, makeInferenceContext(effectiveExpectedType), signatureTracker ); @@ -7876,7 +7884,7 @@ export function createTypeEvaluator( expectedTypeDiagAddendum = result?.expectedTypeDiagAddendum; } - const typeResult = getTypeOfTupleInferred(node); + const typeResult = getTypeOfTupleInferred(node, flags); // If there was an expected type of Any, replace the resulting type // with Any rather than return a type with unknowns. @@ -7889,6 +7897,7 @@ export function createTypeEvaluator( function getTypeOfTupleWithContext( node: TupleNode, + flags: EvaluatorFlags, inferenceContext: InferenceContext, signatureTracker: UniqueSignatureTracker | undefined ): TypeResult | undefined { @@ -7947,7 +7956,7 @@ export function createTypeEvaluator( const entryTypeResults = node.expressions.map((expr, index) => getTypeOfExpression( expr, - /* flags */ undefined, + flags | EvaluatorFlags.StripLiteralTypeForTuple, makeInferenceContext( index < expectedTypes.length ? expectedTypes[index] : undefined, inferenceContext.isTypeIncomplete @@ -7956,7 +7965,7 @@ export function createTypeEvaluator( ) ); const isIncomplete = entryTypeResults.some((result) => result.isIncomplete); - const type = makeTupleObject(buildTupleTypesList(entryTypeResults)); + const type = makeTupleObject(buildTupleTypesList(entryTypeResults, /* stripLiterals */ false)); // Copy any expected type diag addenda for precision error reporting. let expectedTypeDiagAddendum: DiagnosticAddendum | undefined; @@ -7972,11 +7981,15 @@ export function createTypeEvaluator( return { type, expectedTypeDiagAddendum, isIncomplete }; } - function getTypeOfTupleInferred(node: TupleNode): TypeResult { - const entryTypeResults = node.expressions.map((expr) => getTypeOfExpression(expr)); + function getTypeOfTupleInferred(node: TupleNode, flags: EvaluatorFlags): TypeResult { + const entryTypeResults = node.expressions.map((expr) => + getTypeOfExpression(expr, flags | EvaluatorFlags.StripLiteralTypeForTuple) + ); const isIncomplete = entryTypeResults.some((result) => result.isIncomplete); - const type = makeTupleObject(buildTupleTypesList(entryTypeResults)); + const type = makeTupleObject( + buildTupleTypesList(entryTypeResults, (flags & EvaluatorFlags.StripLiteralTypeForTuple) !== 0) + ); if (isIncomplete) { if (getContainerDepth(type) > maxInferredContainerDepth) { @@ -7987,7 +8000,7 @@ export function createTypeEvaluator( return { type, isIncomplete }; } - function buildTupleTypesList(entryTypeResults: TypeResult[]): TupleTypeArgument[] { + function buildTupleTypesList(entryTypeResults: TypeResult[], stripLiterals: boolean): TupleTypeArgument[] { const entryTypes: TupleTypeArgument[] = []; for (const typeResult of entryTypeResults) { @@ -8017,7 +8030,8 @@ export function createTypeEvaluator( } else if (isNever(typeResult.type) && typeResult.isIncomplete && !typeResult.unpackedType) { entryTypes.push({ type: UnknownType.create(/* isIncomplete */ true), isUnbounded: false }); } else { - entryTypes.push({ type: typeResult.type, isUnbounded: !!typeResult.unpackedType }); + const entryType = stripLiterals ? stripLiteralValue(typeResult.type) : typeResult.type; + entryTypes.push({ type: entryType, isUnbounded: !!typeResult.unpackedType }); } } @@ -13318,7 +13332,7 @@ export function createTypeEvaluator( } const subtypeResult = useSpeculativeMode(node, () => { - return getTypeOfDictionaryWithContext(node, makeInferenceContext(subtype)); + return getTypeOfDictionaryWithContext(node, flags, makeInferenceContext(subtype)); }); if (subtypeResult && assignType(subtype, subtypeResult.type)) { @@ -13341,6 +13355,7 @@ export function createTypeEvaluator( expectedTypeDiagAddendum = new DiagnosticAddendum(); const result = getTypeOfDictionaryWithContext( node, + flags, makeInferenceContext(effectiveExpectedType), expectedTypeDiagAddendum ); @@ -13349,12 +13364,13 @@ export function createTypeEvaluator( } } - const result = getTypeOfDictionaryInferred(node, /* hasExpectedType */ !!inferenceContext); + const result = getTypeOfDictionaryInferred(node, flags, /* hasExpectedType */ !!inferenceContext); return { ...result, expectedTypeDiagAddendum }; } function getTypeOfDictionaryWithContext( node: DictionaryNode, + flags: EvaluatorFlags, inferenceContext: InferenceContext, expectedDiagAddendum?: DiagnosticAddendum ): TypeResult | undefined { @@ -13381,6 +13397,7 @@ export function createTypeEvaluator( // Infer the key and value types if possible. const keyValueTypeResult = getKeyAndValueTypesFromDictionary( node, + flags, keyTypes, valueTypes, /* forceStrictInference */ true, @@ -13472,6 +13489,7 @@ export function createTypeEvaluator( // Infer the key and value types if possible. const keyValueResult = getKeyAndValueTypesFromDictionary( node, + flags, keyTypes, valueTypes, /* forceStrictInference */ true, @@ -13510,7 +13528,11 @@ export function createTypeEvaluator( // Attempts to infer the type of a dictionary statement. If hasExpectedType // is true, strict inference is used for the subexpressions. - function getTypeOfDictionaryInferred(node: DictionaryNode, hasExpectedType: boolean): TypeResult { + function getTypeOfDictionaryInferred( + node: DictionaryNode, + flags: EvaluatorFlags, + hasExpectedType: boolean + ): TypeResult { const fallbackType = hasExpectedType ? AnyType.create() : UnknownType.create(); let keyType: Type = fallbackType; let valueType: Type = fallbackType; @@ -13525,6 +13547,7 @@ export function createTypeEvaluator( // Infer the key and value types if possible. const keyValueResult = getKeyAndValueTypesFromDictionary( node, + flags, keyTypeResults, valueTypeResults, /* forceStrictInference */ hasExpectedType, @@ -13586,6 +13609,7 @@ export function createTypeEvaluator( function getKeyAndValueTypesFromDictionary( node: DictionaryNode, + flags: EvaluatorFlags, keyTypes: TypeResultWithNode[], valueTypes: TypeResultWithNode[], forceStrictInference: boolean, @@ -13598,6 +13622,16 @@ export function createTypeEvaluator( let isIncomplete = false; let typeErrors = false; + // Mask out some of the flags that are not applicable for a dictionary key + // even if it appears within an inlined TypedDict annotation. + const keyFlags = + flags & + ~( + EvaluatorFlags.ExpectingTypeAnnotation | + EvaluatorFlags.EvaluateStringLiteralAsType | + EvaluatorFlags.ExpectingInstantiableType + ); + // Infer the key and value types if possible. node.entries.forEach((entryNode, index) => { let addUnknown = true; @@ -13605,7 +13639,7 @@ export function createTypeEvaluator( if (entryNode.nodeType === ParseNodeType.DictionaryKeyEntry) { const keyTypeResult = getTypeOfExpression( entryNode.keyExpression, - /* flags */ undefined, + keyFlags | EvaluatorFlags.StripLiteralTypeForTuple, makeInferenceContext( expectedKeyType ?? (forceStrictInference ? NeverType.createNever() : undefined) ) @@ -13645,7 +13679,7 @@ export function createTypeEvaluator( entryInferenceContext = makeInferenceContext(effectiveValueType); valueTypeResult = getTypeOfExpression( entryNode.valueExpression, - /* flags */ undefined, + flags | EvaluatorFlags.StripLiteralTypeForTuple, entryInferenceContext ); } else { @@ -13654,7 +13688,7 @@ export function createTypeEvaluator( entryInferenceContext = makeInferenceContext(effectiveValueType); valueTypeResult = getTypeOfExpression( entryNode.valueExpression, - /* flags */ undefined, + flags | EvaluatorFlags.StripLiteralTypeForTuple, entryInferenceContext ); } @@ -13717,7 +13751,7 @@ export function createTypeEvaluator( const entryInferenceContext = makeInferenceContext(expectedType); let unexpandedTypeResult = getTypeOfExpression( entryNode.expandExpression, - /* flags */ undefined, + flags | EvaluatorFlags.StripLiteralTypeForTuple, entryInferenceContext ); @@ -13817,6 +13851,7 @@ export function createTypeEvaluator( } else if (entryNode.nodeType === ParseNodeType.ListComprehension) { const dictEntryTypeResult = getElementTypeFromListComprehension( entryNode, + flags, expectedValueType, expectedKeyType ); @@ -13868,6 +13903,12 @@ export function createTypeEvaluator( addDiagnostic(DiagnosticRule.reportInvalidTypeForm, LocMessage.listInAnnotation() + diag.getString(), node); } + flags &= ~( + EvaluatorFlags.ExpectingTypeAnnotation | + EvaluatorFlags.EvaluateStringLiteralAsType | + EvaluatorFlags.ExpectingInstantiableType + ); + // If the expected type is a union, recursively call for each of the subtypes // to find one that matches. let effectiveExpectedType = inferenceContext?.expectedType; @@ -13885,7 +13926,7 @@ export function createTypeEvaluator( } const subtypeResult = useSpeculativeMode(node, () => { - return getTypeOfListOrSetWithContext(node, makeInferenceContext(subtype)); + return getTypeOfListOrSetWithContext(node, flags, makeInferenceContext(subtype)); }); if (subtypeResult && assignType(subtype, subtypeResult.type)) { @@ -13905,7 +13946,7 @@ export function createTypeEvaluator( let expectedTypeDiagAddendum: DiagnosticAddendum | undefined; if (effectiveExpectedType) { - const result = getTypeOfListOrSetWithContext(node, makeInferenceContext(effectiveExpectedType)); + const result = getTypeOfListOrSetWithContext(node, flags, makeInferenceContext(effectiveExpectedType)); if (result && !result.typeErrors) { return result; } @@ -13913,7 +13954,11 @@ export function createTypeEvaluator( expectedTypeDiagAddendum = result?.expectedTypeDiagAddendum; } - const typeResult = getTypeOfListOrSetInferred(node, /* hasExpectedType */ inferenceContext !== undefined); + const typeResult = getTypeOfListOrSetInferred( + node, + flags, + /* hasExpectedType */ inferenceContext !== undefined + ); return { ...typeResult, expectedTypeDiagAddendum }; } @@ -13921,6 +13966,7 @@ export function createTypeEvaluator( // Returns undefined if that type cannot be honored. function getTypeOfListOrSetWithContext( node: ListNode | SetNode, + flags: EvaluatorFlags, inferenceContext: InferenceContext ): TypeResult | undefined { const builtInClassName = node.nodeType === ParseNodeType.List ? 'list' : 'set'; @@ -13945,11 +13991,11 @@ export function createTypeEvaluator( let entryTypeResult: TypeResult; if (entry.nodeType === ParseNodeType.ListComprehension) { - entryTypeResult = getElementTypeFromListComprehension(entry, expectedEntryType); + entryTypeResult = getElementTypeFromListComprehension(entry, flags, expectedEntryType); } else { entryTypeResult = getTypeOfExpression( entry, - /* flags */ undefined, + flags | EvaluatorFlags.StripLiteralTypeForTuple, makeInferenceContext(expectedEntryType) ); } @@ -14044,7 +14090,11 @@ export function createTypeEvaluator( } // Attempts to infer the type of a list or set statement with no "expected type". - function getTypeOfListOrSetInferred(node: ListNode | SetNode, hasExpectedType: boolean): TypeResult { + function getTypeOfListOrSetInferred( + node: ListNode | SetNode, + flags: EvaluatorFlags, + hasExpectedType: boolean + ): TypeResult { const builtInClassName = node.nodeType === ParseNodeType.List ? 'list' : 'set'; const verifyHashable = node.nodeType === ParseNodeType.Set; let isEmptyContainer = false; @@ -14056,9 +14106,9 @@ export function createTypeEvaluator( let entryTypeResult: TypeResult; if (entry.nodeType === ParseNodeType.ListComprehension && !entry.isGenerator) { - entryTypeResult = getElementTypeFromListComprehension(entry); + entryTypeResult = getElementTypeFromListComprehension(entry, flags); } else { - entryTypeResult = getTypeOfExpression(entry); + entryTypeResult = getTypeOfExpression(entry, flags | EvaluatorFlags.StripLiteralTypeForTuple); } if (entryTypeResult.isIncomplete) { @@ -14479,7 +14529,11 @@ export function createTypeEvaluator( return { type: functionType, isIncomplete, typeErrors }; } - function getTypeOfListComprehension(node: ListComprehensionNode, inferenceContext?: InferenceContext): TypeResult { + function getTypeOfListComprehension( + node: ListComprehensionNode, + flags: EvaluatorFlags, + inferenceContext?: InferenceContext + ): TypeResult { let isIncomplete = false; let typeErrors = false; @@ -14501,7 +14555,7 @@ export function createTypeEvaluator( const builtInIteratorType = getTypingType(node, isAsync ? 'AsyncGenerator' : 'Generator'); const expectedEntryType = getExpectedEntryTypeForIterable(node, builtInIteratorType, inferenceContext); - const elementTypeResult = getElementTypeFromListComprehension(node, expectedEntryType); + const elementTypeResult = getElementTypeFromListComprehension(node, flags, expectedEntryType); if (elementTypeResult.isIncomplete) { isIncomplete = true; } @@ -14610,6 +14664,7 @@ export function createTypeEvaluator( // as opposed to the entire list. function getElementTypeFromListComprehension( node: ListComprehensionNode, + flags: EvaluatorFlags, expectedValueOrElementType?: Type, expectedKeyType?: Type ): TypeResult { @@ -14628,7 +14683,7 @@ export function createTypeEvaluator( // Create a tuple with the key/value types. const keyTypeResult = getTypeOfExpression( node.expression.keyExpression, - /* flags */ undefined, + flags | EvaluatorFlags.StripLiteralTypeForTuple, makeInferenceContext(expectedKeyType) ); if (keyTypeResult.isIncomplete) { @@ -14644,7 +14699,7 @@ export function createTypeEvaluator( const valueTypeResult = getTypeOfExpression( node.expression.valueExpression, - /* flags */ undefined, + flags | EvaluatorFlags.StripLiteralTypeForTuple, makeInferenceContext(expectedValueOrElementType) ); if (valueTypeResult.isIncomplete) { @@ -14666,13 +14721,13 @@ export function createTypeEvaluator( // The parser should have reported an error in this case because it's not allowed. getTypeOfExpression( node.expression.expandExpression, - /* flags */ undefined, + flags | EvaluatorFlags.StripLiteralTypeForTuple, makeInferenceContext(expectedValueOrElementType) ); } else if (isExpressionNode(node)) { const exprTypeResult = getTypeOfExpression( node.expression as ExpressionNode, - /* flags */ undefined, + flags | EvaluatorFlags.StripLiteralTypeForTuple, makeInferenceContext(expectedValueOrElementType) ); if (exprTypeResult.isIncomplete) { diff --git a/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts b/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts index 1e55bb40a..90bfe9553 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts @@ -154,6 +154,10 @@ export const enum EvaluatorFlags { // Allow use of the Concatenate special form. AllowConcatenate = 1 << 27, + // Do not infer literal types within a tuple (used for tuples nested within + // other container classes). + StripLiteralTypeForTuple = 1 << 28, + // Defaults used for evaluating the LHS of a call expression. CallBaseDefaults = DoNotSpecialize, diff --git a/packages/pyright-internal/src/tests/samples/comprehension11.py b/packages/pyright-internal/src/tests/samples/comprehension11.py index 311c146f4..34717f076 100644 --- a/packages/pyright-internal/src/tests/samples/comprehension11.py +++ b/packages/pyright-internal/src/tests/samples/comprehension11.py @@ -15,4 +15,4 @@ times = [ for meridian in ("am", "pm") ) ] -reveal_type(times, expected_text="list[tuple[int, int, Literal['am', 'pm']]]") +reveal_type(times, expected_text="list[tuple[int, int, str]]") diff --git a/packages/pyright-internal/src/tests/samples/generator14.py b/packages/pyright-internal/src/tests/samples/generator14.py index b923bc360..65df1d3cb 100644 --- a/packages/pyright-internal/src/tests/samples/generator14.py +++ b/packages/pyright-internal/src/tests/samples/generator14.py @@ -1,8 +1,7 @@ # This sample tests the inferred type of async and sync generators. -async def foo() -> int: - ... +async def foo() -> int: ... async def main() -> None: @@ -19,7 +18,7 @@ async def main() -> None: reveal_type(v4, expected_text="AsyncGenerator[int, None]") v5 = ((0, await foo()) for _ in [1, 2]) - reveal_type(v5, expected_text="AsyncGenerator[tuple[Literal[0], int], None]") + reveal_type(v5, expected_text="AsyncGenerator[tuple[int, int], None]") v6 = (x for x in [1, 2] if (x, await foo())) reveal_type(v6, expected_text="AsyncGenerator[int, None]") diff --git a/packages/pyright-internal/src/tests/typeEvaluator7.test.ts b/packages/pyright-internal/src/tests/typeEvaluator7.test.ts index fd42f7699..1e36abef4 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator7.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator7.test.ts @@ -742,7 +742,7 @@ test('TypedDictInline1', () => { configOptions.diagnosticRuleSet.enableExperimentalFeatures = true; const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typedDictInline1.py'], configOptions); - TestUtils.validateResults(analysisResults, 8); + TestUtils.validateResults(analysisResults, 9); }); test('ClassVar1', () => {