diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index e0971032b..00549d926 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -19130,6 +19130,20 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions } } + // Handle the special case where the dest is a union of Any and + // a type variable and CanAssignFlags.AllowTypeVarNarrowing is + // in effect. This occurs, for example, with the return type of + // the getattr function. + if ((flags & CanAssignFlags.AllowTypeVarNarrowing) !== 0 && isUnion(destType)) { + const nonAnySubtypes = destType.subtypes.filter((t) => !isAnyOrUnknown(t)); + if (nonAnySubtypes.length === 1 && isTypeVar(nonAnySubtypes[0])) { + canAssignType(nonAnySubtypes[0], srcType, /* diag */ undefined, typeVarMap, flags, recursionCount + 1); + + // This always succeeds because the destination contains Any. + return true; + } + } + // For union sources, all of the types need to be assignable to the dest. let isIncompatible = false; doForEachSubtype(srcType, (subtype) => { diff --git a/packages/pyright-internal/src/tests/samples/genericTypes72.py b/packages/pyright-internal/src/tests/samples/genericTypes72.py new file mode 100644 index 000000000..4cbe80a46 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/genericTypes72.py @@ -0,0 +1,15 @@ +# This sample tests a special case of bidirectional type inference when +# the expected type is a union and the destination type is a union that +# contains Any and a TypeVar. + + +from typing import Any, Literal, TypeVar + +_T = TypeVar("_T") + + +def getattr(__o: object, name: str, __default: _T) -> Any | _T: + ... + + +x: Literal[1, 2, 3] = getattr(object(), "", 1) diff --git a/packages/pyright-internal/src/tests/typeEvaluator2.test.ts b/packages/pyright-internal/src/tests/typeEvaluator2.test.ts index b1c4064bb..83c70c3de 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator2.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator2.test.ts @@ -757,6 +757,12 @@ test('GenericTypes71', () => { TestUtils.validateResults(analysisResults, 4); }); +test('GenericTypes72', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['genericTypes72.py']); + + TestUtils.validateResults(analysisResults, 0); +}); + test('Protocol1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['protocol1.py']);