From 81f21ee3d4ada50505c6550ba38d9bda68d7753e Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Wed, 1 Mar 2023 23:13:01 -0700 Subject: [PATCH] Added support for type guard patterns `type(x) == T` and `type(x) != T`. This addresses https://github.com/microsoft/pyright/issues/4719. --- docs/type-concepts.md | 1 + .../src/analyzer/typeGuards.ts | 4 +- .../tests/samples/typeNarrowingTypeEquals1.py | 110 ++++++++++++++++++ .../src/tests/typeEvaluator1.test.ts | 6 + 4 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 packages/pyright-internal/src/tests/samples/typeNarrowingTypeEquals1.py diff --git a/docs/type-concepts.md b/docs/type-concepts.md index f4cbc40b3..673ddca2e 100644 --- a/docs/type-concepts.md +++ b/docs/type-concepts.md @@ -171,6 +171,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t * `x is ...` and `x is not ...` * `x == ...` and `x != ...` * `type(x) is T` and `type(x) is not T` +* `type(x) == T` and `type(x) != T` * `x is E` and `x is not E` (where E is a literal enum or bool) * `x == L` and `x != L` (where L is an expression that evaluates to a literal type) * `x.y is None` and `x.y is not None` (where x is a type that is distinguished by a field with a None) diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 5025a6d5c..d942a40a3 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -191,8 +191,8 @@ export function getTypeNarrowingCallback( } } - // Look for "type(X) is Y" or "type(X) is not Y". - if (isOrIsNotOperator && testExpression.leftExpression.nodeType === ParseNodeType.Call) { + // Look for "type(X) is Y", "type(X) is not Y", "type(X) == Y" or "type(X) != Y". + if (testExpression.leftExpression.nodeType === ParseNodeType.Call) { if ( testExpression.leftExpression.arguments.length === 1 && testExpression.leftExpression.arguments[0].argumentCategory === ArgumentCategory.Simple diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingTypeEquals1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingTypeEquals1.py new file mode 100644 index 000000000..804d681a4 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingTypeEquals1.py @@ -0,0 +1,110 @@ +# This sample exercises the type analyzer's type narrowing +# logic for tests of the form "type(X) == Y" or "type(X) != Y". + +from typing import Any, Dict, Generic, Optional, TypeVar, Union, final + + +def func1(a: Union[str, int]) -> int: + + if type(a) != str: + # This should generate an error because + # "a" is potentially a subclass of str. + return a + + # This should generate an error because + # "a" is provably type str at this point. + return a + + +def func2(a: Optional[str]) -> str: + + if type(a) == str: + return a + + # This should generate an error because + # "a" is provably type str at this point. + return a + + +def func3(a: Dict[str, Any]) -> str: + val = a.get("hello") + if type(val) == str: + return val + + return "none" + + +class A: + pass + + +class B(A): + pass + + +def func4(a: Union[str, A]): + if type(a) == B: + reveal_type(a, expected_text="B") + else: + reveal_type(a, expected_text="str | A") + + +T = TypeVar("T") + + +class C(Generic[T]): + def __init__(self, a: T): + self.a = a + + +class D: + pass + + +E = Union[C[T], D] + + +def func5(x: E[T]) -> None: + if type(x) == C: + reveal_type(x, expected_text="C[T@func5]") + + +@final +class AFinal: + pass + + +@final +class BFinal: + pass + + +def func6(val: Union[AFinal, BFinal]) -> None: + if type(val) == AFinal: + reveal_type(val, expected_text="AFinal") + else: + reveal_type(val, expected_text="BFinal") + + +def func7(val: Any): + if type(val) == int: + reveal_type(val, expected_text="int") + else: + reveal_type(val, expected_text="Any") + + reveal_type(val, expected_text="int | Any") + +class CParent: + ... + +class CChild(CParent): + ... + +_TC = TypeVar("_TC", bound=CParent) + +def func8(a: _TC, b: _TC) -> _TC: + if type(a) == CChild: + reveal_type(a, expected_text="CChild*") + return a + reveal_type(a, expected_text="CParent*") + return a diff --git a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts index d05a4b091..95f97d1c0 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts @@ -297,6 +297,12 @@ test('TypeNarrowingTypeIs1', () => { TestUtils.validateResults(analysisResults, 3); }); +test('TypeNarrowingTypeEquals1', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowingTypeEquals1.py']); + + TestUtils.validateResults(analysisResults, 3); +}); + test('TypeNarrowingIsNone1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowingIsNone1.py']);