diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index cafa4ad35..40528de0d 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -396,6 +396,7 @@ interface AliasMapEntry { alias: string; module: 'builtins' | 'collections' | 'self'; isSpecialForm?: boolean; + typeParamVariance?: Variance; } interface AssignClassToSelfInfo { @@ -15547,6 +15548,20 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions specialClassType.details.flags |= ClassTypeFlags.SpecialFormClass; } + // Synthesize a single type parameter with the specified variance if + // specified in the alias map entry. + if (aliasMapEntry.typeParamVariance !== undefined) { + let typeParam = TypeVarType.createInstance('T'); + typeParam = TypeVarType.cloneForScopeId( + typeParam, + ParseTreeUtils.getScopeIdForNode(node), + assignedName, + TypeVarScopeType.Class + ); + typeParam.details.declaredVariance = aliasMapEntry.typeParamVariance; + specialClassType.details.typeParameters.push(typeParam); + } + const specialBuiltInClassDeclaration = (AnalyzerNodeInfo.getDeclaration(node) ?? (node.parent ? AnalyzerNodeInfo.getDeclaration(node.parent) : undefined)) as | SpecialBuiltInClassDeclaration @@ -15629,7 +15644,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions ['Annotated', { alias: '', module: 'builtins', isSpecialForm: true }], ['TypeAlias', { alias: '', module: 'builtins', isSpecialForm: true }], ['Concatenate', { alias: '', module: 'builtins', isSpecialForm: true }], - ['TypeGuard', { alias: '', module: 'builtins', isSpecialForm: true }], + [ + 'TypeGuard', + { alias: '', module: 'builtins', isSpecialForm: true, typeParamVariance: Variance.Covariant }, + ], ['Unpack', { alias: '', module: 'builtins', isSpecialForm: true }], ['Required', { alias: '', module: 'builtins', isSpecialForm: true }], ['NotRequired', { alias: '', module: 'builtins', isSpecialForm: true }], @@ -15638,7 +15656,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions ['Never', { alias: '', module: 'builtins', isSpecialForm: true }], ['LiteralString', { alias: '', module: 'builtins', isSpecialForm: true }], ['ReadOnly', { alias: '', module: 'builtins', isSpecialForm: true }], - ['TypeIs', { alias: '', module: 'builtins', isSpecialForm: true }], + ['TypeIs', { alias: '', module: 'builtins', isSpecialForm: true, typeParamVariance: Variance.Invariant }], ]); const aliasMapEntry = specialTypes.get(assignedName); @@ -23522,7 +23540,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions return !isLiteral; } } else if (ClassType.isBuiltIn(destType, ['TypeGuard', 'TypeIs'])) { - // All the source to be a "bool". + // Allow the source to be a "bool". if ((originalFlags & AssignTypeFlags.AllowBoolTypeGuard) !== 0) { if (isClassInstance(srcType) && ClassType.isBuiltIn(srcType, 'bool')) { return true; diff --git a/packages/pyright-internal/src/tests/samples/typeGuard1.py b/packages/pyright-internal/src/tests/samples/typeGuard1.py index 0525f5fac..41e9160ad 100644 --- a/packages/pyright-internal/src/tests/samples/typeGuard1.py +++ b/packages/pyright-internal/src/tests/samples/typeGuard1.py @@ -5,7 +5,7 @@ # pyright: reportMissingModuleSource=false import os -from typing import Any, TypeVar +from typing import Any, Callable, TypeVar from typing_extensions import TypeGuard # pyright: ignore[reportMissingModuleSource] _T = TypeVar("_T") @@ -100,3 +100,26 @@ def func4(typ: type[_T]) -> _T: raise Exception("Unsupported type") return typ() + + +def takes_int_typeguard(f: Callable[[object], TypeGuard[int]]) -> None: + pass + + +def int_typeguard(val: object) -> TypeGuard[int]: + return isinstance(val, int) + + +def bool_typeguard(val: object) -> TypeGuard[bool]: + return isinstance(val, bool) + + +def str_typeguard(val: object) -> TypeGuard[str]: + return isinstance(val, str) + + +takes_int_typeguard(int_typeguard) +takes_int_typeguard(bool_typeguard) + +# This should generate an error because TypeGuard is covariant. +takes_int_typeguard(str_typeguard) diff --git a/packages/pyright-internal/src/tests/samples/typeIs1.py b/packages/pyright-internal/src/tests/samples/typeIs1.py index 3a8bc7805..4d89afe40 100644 --- a/packages/pyright-internal/src/tests/samples/typeIs1.py +++ b/packages/pyright-internal/src/tests/samples/typeIs1.py @@ -1,6 +1,6 @@ # This sample tests the TypeIs form. -from typing import Any, Literal, Mapping, Sequence, TypeVar, Union +from typing import Any, Callable, Literal, Mapping, Sequence, TypeVar, Union from typing_extensions import TypeIs # pyright: ignore[reportMissingModuleSource] @@ -88,7 +88,23 @@ def is_marsupial(val: Animal) -> TypeIs[Kangaroo | Koala]: # This should generate an error because list[T] isn't consistent with list[T | None]. -def has_no_nones( - val: list[T | None], -) -> TypeIs[list[T]]: +def has_no_nones(val: list[T | None]) -> TypeIs[list[T]]: return None not in val + + +def takes_int_typeis(f: Callable[[object], TypeIs[int]]) -> None: + pass + + +def int_typeis(val: object) -> TypeIs[int]: + return isinstance(val, int) + + +def bool_typeis(val: object) -> TypeIs[bool]: + return isinstance(val, bool) + + +takes_int_typeis(int_typeis) + +# This should generate an error because TypeIs is invariant. +takes_int_typeis(bool_typeis) diff --git a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts index 3d67d43ad..cc3da5afd 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts @@ -1013,7 +1013,7 @@ test('EnumGenNextValue1', () => { test('TypeGuard1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeGuard1.py']); - TestUtils.validateResults(analysisResults, 7); + TestUtils.validateResults(analysisResults, 8); }); test('TypeGuard2', () => { @@ -1029,7 +1029,7 @@ test('TypeGuard3', () => { test('TypeIs1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeIs1.py']); - TestUtils.validateResults(analysisResults, 1); + TestUtils.validateResults(analysisResults, 2); }); test('Never1', () => {