Added support for type guard patterns type(x) == T and type(x) != T. This addresses https://github.com/microsoft/pyright/issues/4719.

This commit is contained in:
Eric Traut 2023-03-01 23:13:01 -07:00
parent ac3487e03f
commit 81f21ee3d4
4 changed files with 119 additions and 2 deletions

View File

@ -171,6 +171,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t
* `x is ...` and `x is not ...`
* `x == ...` and `x != ...`
* `type(x) is T` and `type(x) is not T`
* `type(x) == T` and `type(x) != T`
* `x is E` and `x is not E` (where E is a literal enum or bool)
* `x == L` and `x != L` (where L is an expression that evaluates to a literal type)
* `x.y is None` and `x.y is not None` (where x is a type that is distinguished by a field with a None)

View File

@ -191,8 +191,8 @@ export function getTypeNarrowingCallback(
}
}
// Look for "type(X) is Y" or "type(X) is not Y".
if (isOrIsNotOperator && testExpression.leftExpression.nodeType === ParseNodeType.Call) {
// Look for "type(X) is Y", "type(X) is not Y", "type(X) == Y" or "type(X) != Y".
if (testExpression.leftExpression.nodeType === ParseNodeType.Call) {
if (
testExpression.leftExpression.arguments.length === 1 &&
testExpression.leftExpression.arguments[0].argumentCategory === ArgumentCategory.Simple

View File

@ -0,0 +1,110 @@
# This sample exercises the type analyzer's type narrowing
# logic for tests of the form "type(X) == Y" or "type(X) != Y".
from typing import Any, Dict, Generic, Optional, TypeVar, Union, final
def func1(a: Union[str, int]) -> int:
if type(a) != str:
# This should generate an error because
# "a" is potentially a subclass of str.
return a
# This should generate an error because
# "a" is provably type str at this point.
return a
def func2(a: Optional[str]) -> str:
if type(a) == str:
return a
# This should generate an error because
# "a" is provably type str at this point.
return a
def func3(a: Dict[str, Any]) -> str:
val = a.get("hello")
if type(val) == str:
return val
return "none"
class A:
pass
class B(A):
pass
def func4(a: Union[str, A]):
if type(a) == B:
reveal_type(a, expected_text="B")
else:
reveal_type(a, expected_text="str | A")
T = TypeVar("T")
class C(Generic[T]):
def __init__(self, a: T):
self.a = a
class D:
pass
E = Union[C[T], D]
def func5(x: E[T]) -> None:
if type(x) == C:
reveal_type(x, expected_text="C[T@func5]")
@final
class AFinal:
pass
@final
class BFinal:
pass
def func6(val: Union[AFinal, BFinal]) -> None:
if type(val) == AFinal:
reveal_type(val, expected_text="AFinal")
else:
reveal_type(val, expected_text="BFinal")
def func7(val: Any):
if type(val) == int:
reveal_type(val, expected_text="int")
else:
reveal_type(val, expected_text="Any")
reveal_type(val, expected_text="int | Any")
class CParent:
...
class CChild(CParent):
...
_TC = TypeVar("_TC", bound=CParent)
def func8(a: _TC, b: _TC) -> _TC:
if type(a) == CChild:
reveal_type(a, expected_text="CChild*")
return a
reveal_type(a, expected_text="CParent*")
return a

View File

@ -297,6 +297,12 @@ test('TypeNarrowingTypeIs1', () => {
TestUtils.validateResults(analysisResults, 3);
});
test('TypeNarrowingTypeEquals1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowingTypeEquals1.py']);
TestUtils.validateResults(analysisResults, 3);
});
test('TypeNarrowingIsNone1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowingIsNone1.py']);